From 1974ac7f9c90a5ab7dc1d0433f97b16825c94377 Mon Sep 17 00:00:00 2001 From: Shijie Sheng Date: Wed, 29 Oct 2025 11:03:38 -0700 Subject: [PATCH 1/2] chore: add ruff format check Signed-off-by: Shijie Sheng --- Makefile | 2 +- cadence/_internal/__init__.py | 2 - cadence/_internal/activity/__init__.py | 6 +- .../_internal/activity/_activity_executor.py | 54 ++-- cadence/_internal/activity/_context.py | 17 +- cadence/_internal/decision_state_machine.py | 189 ++++++------ cadence/_internal/rpc/error.py | 65 +++-- cadence/_internal/rpc/retry.py | 37 ++- cadence/_internal/rpc/yarpc.py | 18 +- cadence/_internal/workflow/context.py | 1 - .../workflow/decision_events_iterator.py | 165 ++++++----- .../_internal/workflow/decisions_helper.py | 7 +- .../workflow/deterministic_event_loop.py | 28 +- .../workflow/history_event_iterator.py | 27 +- cadence/_internal/workflow/workflow_engine.py | 164 ++++++----- cadence/activity.py | 69 +++-- cadence/data_converter.py | 20 +- cadence/error.py | 58 +++- cadence/metrics/__init__.py | 2 +- cadence/metrics/constants.py | 73 +++-- cadence/metrics/metrics.py | 4 - cadence/metrics/prometheus.py | 2 - cadence/sample/__init__.py | 2 +- cadence/sample/client_example.py | 3 +- cadence/sample/grpc_usage_example.py | 79 +++-- cadence/sample/simple_usage_example.py | 60 ++-- cadence/worker/__init__.py | 11 +- cadence/worker/_activity.py | 39 ++- cadence/worker/_base_task_handler.py | 25 +- cadence/worker/_decision_task_handler.py | 157 ++++++---- cadence/worker/_poller.py | 21 +- cadence/worker/_registry.py | 91 +++--- cadence/worker/_types.py | 1 + cadence/worker/_worker.py | 15 +- cadence/workflow.py | 43 ++- scripts/dev.py | 43 ++- scripts/generate_proto.py | 37 ++- .../activity/test_activity_executor.py | 163 +++++++---- tests/cadence/_internal/rpc/test_error.py | 138 ++++++--- tests/cadence/_internal/rpc/test_retry.py | 137 +++++---- .../_internal/test_decision_state_machine.py | 45 ++- .../workflow/test_decision_events_iterator.py | 158 +++++----- .../workflow/test_deterministic_event_loop.py | 10 +- .../workflow/test_history_event_iterator.py | 84 +++--- .../test_workflow_engine_integration.py | 205 ++++++++----- tests/cadence/common_activities.py | 14 +- tests/cadence/data_converter_test.py | 103 ++++--- tests/cadence/metrics/test_metrics.py | 7 +- tests/cadence/metrics/test_prometheus.py | 63 ++-- .../cadence/worker/test_base_task_handler.py | 50 ++-- .../worker/test_decision_task_handler.py | 270 +++++++++++------- .../test_decision_task_handler_integration.py | 157 ++++++---- .../test_decision_worker_integration.py | 140 +++++---- tests/cadence/worker/test_poller.py | 8 +- tests/cadence/worker/test_registry.py | 25 +- .../worker/test_task_handler_integration.py | 183 +++++++----- tests/cadence/worker/test_worker.py | 44 ++- tests/conftest.py | 10 +- tests/integration_tests/conftest.py | 20 +- tests/integration_tests/helper.py | 2 +- tests/integration_tests/test_client.py | 82 ++++-- 61 files changed, 2334 insertions(+), 1421 deletions(-) diff --git a/Makefile b/Makefile index cd70d0a..e5af680 100644 --- a/Makefile +++ b/Makefile @@ -18,6 +18,7 @@ generate: lint: @echo "Running Ruff linter and fixing lint issues..." 
uv tool run ruff check --fix + uv tool run ruff format # Run type checker type-check: @@ -52,4 +53,3 @@ help: @echo " make integration-test - Run integration tests" @echo " make clean - Remove generated files and caches" @echo " make help - Show this help message" - diff --git a/cadence/_internal/__init__.py b/cadence/_internal/__init__.py index ba70374..ee11482 100644 --- a/cadence/_internal/__init__.py +++ b/cadence/_internal/__init__.py @@ -6,5 +6,3 @@ """ __all__: list[str] = [] - - diff --git a/cadence/_internal/activity/__init__.py b/cadence/_internal/activity/__init__.py index 073d53c..3c62818 100644 --- a/cadence/_internal/activity/__init__.py +++ b/cadence/_internal/activity/__init__.py @@ -1,8 +1,4 @@ - - -from ._activity_executor import ( - ActivityExecutor -) +from ._activity_executor import ActivityExecutor __all__ = [ "ActivityExecutor", diff --git a/cadence/_internal/activity/_activity_executor.py b/cadence/_internal/activity/_activity_executor.py index 56f717d..6c2a7c2 100644 --- a/cadence/_internal/activity/_activity_executor.py +++ b/cadence/_internal/activity/_activity_executor.py @@ -8,21 +8,33 @@ from cadence._internal.activity._context import _Context, _SyncContext from cadence.activity import ActivityInfo, ActivityDefinition, ExecutionStrategy from cadence.api.v1.common_pb2 import Failure -from cadence.api.v1.service_worker_pb2 import PollForActivityTaskResponse, RespondActivityTaskFailedRequest, \ - RespondActivityTaskCompletedRequest +from cadence.api.v1.service_worker_pb2 import ( + PollForActivityTaskResponse, + RespondActivityTaskFailedRequest, + RespondActivityTaskCompletedRequest, +) from cadence.client import Client _logger = getLogger(__name__) + class ActivityExecutor: - def __init__(self, client: Client, task_list: str, identity: str, max_workers: int, registry: Callable[[str], ActivityDefinition]): + def __init__( + self, + client: Client, + task_list: str, + identity: str, + max_workers: int, + registry: Callable[[str], ActivityDefinition], + ): self._client = client self._data_converter = client.data_converter self._registry = registry self._identity = identity self._task_list = task_list - self._thread_pool = ThreadPoolExecutor(max_workers=max_workers, - thread_name_prefix=f'{task_list}-activity-') + self._thread_pool = ThreadPoolExecutor( + max_workers=max_workers, thread_name_prefix=f"{task_list}-activity-" + ) async def execute(self, task: PollForActivityTaskResponse): try: @@ -46,27 +58,33 @@ def _create_context(self, task: PollForActivityTaskResponse) -> _Context: else: return _SyncContext(self._client, info, activity_def, self._thread_pool) - async def _report_failure(self, task: PollForActivityTaskResponse, error: Exception): + async def _report_failure( + self, task: PollForActivityTaskResponse, error: Exception + ): try: - await self._client.worker_stub.RespondActivityTaskFailed(RespondActivityTaskFailedRequest( - task_token=task.task_token, - failure=_to_failure(error), - identity=self._identity, - )) + await self._client.worker_stub.RespondActivityTaskFailed( + RespondActivityTaskFailedRequest( + task_token=task.task_token, + failure=_to_failure(error), + identity=self._identity, + ) + ) except Exception: - _logger.exception('Exception reporting activity failure') + _logger.exception("Exception reporting activity failure") async def _report_success(self, task: PollForActivityTaskResponse, result: Any): as_payload = await self._data_converter.to_data([result]) try: - await 
self._client.worker_stub.RespondActivityTaskCompleted(RespondActivityTaskCompletedRequest( - task_token=task.task_token, - result=as_payload, - identity=self._identity, - )) + await self._client.worker_stub.RespondActivityTaskCompleted( + RespondActivityTaskCompletedRequest( + task_token=task.task_token, + result=as_payload, + identity=self._identity, + ) + ) except Exception: - _logger.exception('Exception reporting activity complete') + _logger.exception("Exception reporting activity complete") def _create_info(self, task: PollForActivityTaskResponse) -> ActivityInfo: return ActivityInfo( diff --git a/cadence/_internal/activity/_context.py b/cadence/_internal/activity/_context.py index ce2f94b..22f7f85 100644 --- a/cadence/_internal/activity/_context.py +++ b/cadence/_internal/activity/_context.py @@ -8,7 +8,12 @@ class _Context(ActivityContext): - def __init__(self, client: Client, info: ActivityInfo, activity_fn: ActivityDefinition[[Any], Any]): + def __init__( + self, + client: Client, + info: ActivityInfo, + activity_fn: ActivityDefinition[[Any], Any], + ): self._client = client self._info = info self._activity_fn = activity_fn @@ -28,8 +33,15 @@ def client(self) -> Client: def info(self) -> ActivityInfo: return self._info + class _SyncContext(_Context): - def __init__(self, client: Client, info: ActivityInfo, activity_fn: ActivityDefinition[[Any], Any], executor: ThreadPoolExecutor): + def __init__( + self, + client: Client, + info: ActivityInfo, + activity_fn: ActivityDefinition[[Any], Any], + executor: ThreadPoolExecutor, + ): super().__init__(client, info, activity_fn) self._executor = executor @@ -44,4 +56,3 @@ def _run(self, args: list[Any]) -> Any: def client(self) -> Client: raise RuntimeError("client is only supported in async activities") - diff --git a/cadence/_internal/decision_state_machine.py b/cadence/_internal/decision_state_machine.py index 8b721dc..a5b4db9 100644 --- a/cadence/_internal/decision_state_machine.py +++ b/cadence/_internal/decision_state_machine.py @@ -44,7 +44,7 @@ def to_string(cls, state: DecisionState) -> str: class DecisionType(Enum): """Types of decisions that can be made by state machines.""" - + ACTIVITY = 0 CHILD_WORKFLOW = 1 CANCELLATION = 2 @@ -81,13 +81,26 @@ def __str__(self) -> str: @dataclass class StateTransition: """Represents a state transition with associated actions.""" + next_state: Optional[DecisionState] - action: Optional[Callable[['BaseDecisionStateMachine', history.HistoryEvent], None]] = None - condition: Optional[Callable[['BaseDecisionStateMachine', history.HistoryEvent], bool]] = None + action: Optional[ + Callable[["BaseDecisionStateMachine", history.HistoryEvent], None] + ] = None + condition: Optional[ + Callable[["BaseDecisionStateMachine", history.HistoryEvent], bool] + ] = None class TransitionInfo(TypedDict): - type: Literal["initiated", "started", "completion", "canceled", "cancel_initiated", "cancel_failed", "initiation_failed"] + type: Literal[ + "initiated", + "started", + "completion", + "canceled", + "cancel_initiated", + "cancel_failed", + "initiation_failed", + ] decision_type: DecisionType transition: StateTransition @@ -96,153 +109,143 @@ class TransitionInfo(TypedDict): "activity_task_scheduled_event_attributes": { "type": "initiated", "decision_type": DecisionType.ACTIVITY, - "transition": StateTransition( - next_state=DecisionState.INITIATED - ) + "transition": StateTransition(next_state=DecisionState.INITIATED), }, "activity_task_started_event_attributes": { "type": "started", "decision_type": 
DecisionType.ACTIVITY, - "transition": StateTransition( - next_state=DecisionState.STARTED - ) + "transition": StateTransition(next_state=DecisionState.STARTED), }, "activity_task_completed_event_attributes": { "type": "completion", "decision_type": DecisionType.ACTIVITY, "transition": StateTransition( next_state=DecisionState.COMPLETED, - action=lambda self, event: setattr(self, 'status', DecisionState.COMPLETED) - ) + action=lambda self, event: setattr(self, "status", DecisionState.COMPLETED), + ), }, "activity_task_failed_event_attributes": { "type": "completion", "decision_type": DecisionType.ACTIVITY, "transition": StateTransition( next_state=DecisionState.COMPLETED, - action=lambda self, event: setattr(self, 'status', DecisionState.COMPLETED) - ) + action=lambda self, event: setattr(self, "status", DecisionState.COMPLETED), + ), }, "activity_task_timed_out_event_attributes": { "type": "completion", "decision_type": DecisionType.ACTIVITY, "transition": StateTransition( next_state=DecisionState.COMPLETED, - action=lambda self, event: setattr(self, 'status', DecisionState.COMPLETED) - ) + action=lambda self, event: setattr(self, "status", DecisionState.COMPLETED), + ), + }, + "activity_task_cancel_requested_event_attributes": { + "type": "cancel_initiated", + "decision_type": DecisionType.CANCELLATION, + "transition": StateTransition( + next_state=None, + action=lambda self, event: setattr(self, "_cancel_requested", True), + ), + }, + "activity_task_canceled_event_attributes": { + "type": "canceled", + "decision_type": DecisionType.ACTIVITY, + "transition": StateTransition( + next_state=DecisionState.CANCELED_AFTER_INITIATED + ), + }, + "request_cancel_activity_task_failed_event_attributes": { + "type": "cancel_failed", + "decision_type": DecisionType.CANCELLATION, + "transition": StateTransition( + next_state=None, + action=lambda self, event: setattr(self, "_cancel_emitted", False), + ), }, - "activity_task_cancel_requested_event_attributes": { - "type": "cancel_initiated", - "decision_type": DecisionType.CANCELLATION, - "transition": StateTransition( - next_state=None, - action=lambda self, event: setattr(self, '_cancel_requested', True) - ) - }, - "activity_task_canceled_event_attributes": { - "type": "canceled", - "decision_type": DecisionType.ACTIVITY, - "transition": StateTransition( - next_state=DecisionState.CANCELED_AFTER_INITIATED - ) - }, - "request_cancel_activity_task_failed_event_attributes": { - "type": "cancel_failed", - "decision_type": DecisionType.CANCELLATION, - "transition": StateTransition( - next_state=None, - action=lambda self, event: setattr(self, '_cancel_emitted', False) - ) - }, "timer_started_event_attributes": { "type": "initiated", "decision_type": DecisionType.TIMER, - "transition": StateTransition( - next_state=DecisionState.INITIATED - ) + "transition": StateTransition(next_state=DecisionState.INITIATED), }, "timer_fired_event_attributes": { "type": "completion", "decision_type": DecisionType.TIMER, "transition": StateTransition( next_state=DecisionState.COMPLETED, - action=lambda self, event: setattr(self, 'status', DecisionState.COMPLETED) - ) + action=lambda self, event: setattr(self, "status", DecisionState.COMPLETED), + ), }, "timer_canceled_event_attributes": { "type": "canceled", "decision_type": DecisionType.TIMER, "transition": StateTransition( next_state=DecisionState.CANCELED_AFTER_INITIATED - ) + ), + }, + "cancel_timer_failed_event_attributes": { + "type": "cancel_failed", + "decision_type": DecisionType.CANCELLATION, + "transition": 
StateTransition( + next_state=None, + action=lambda self, event: setattr(self, "_cancel_emitted", False), + ), }, - "cancel_timer_failed_event_attributes": { - "type": "cancel_failed", - "decision_type": DecisionType.CANCELLATION, - "transition": StateTransition( - next_state=None, - action=lambda self, event: setattr(self, '_cancel_emitted', False) - ) - }, "start_child_workflow_execution_initiated_event_attributes": { "type": "initiated", "decision_type": DecisionType.CHILD_WORKFLOW, - "transition": StateTransition( - next_state=DecisionState.INITIATED - ) + "transition": StateTransition(next_state=DecisionState.INITIATED), }, "child_workflow_execution_started_event_attributes": { "type": "started", "decision_type": DecisionType.CHILD_WORKFLOW, - "transition": StateTransition( - next_state=DecisionState.STARTED - ) + "transition": StateTransition(next_state=DecisionState.STARTED), }, "child_workflow_execution_completed_event_attributes": { "type": "completion", "decision_type": DecisionType.CHILD_WORKFLOW, "transition": StateTransition( next_state=DecisionState.COMPLETED, - action=lambda self, event: setattr(self, 'status', DecisionState.COMPLETED) - ) + action=lambda self, event: setattr(self, "status", DecisionState.COMPLETED), + ), }, "child_workflow_execution_failed_event_attributes": { "type": "completion", "decision_type": DecisionType.CHILD_WORKFLOW, "transition": StateTransition( next_state=DecisionState.COMPLETED, - action=lambda self, event: setattr(self, 'status', DecisionState.COMPLETED) - ) + action=lambda self, event: setattr(self, "status", DecisionState.COMPLETED), + ), }, "child_workflow_execution_timed_out_event_attributes": { "type": "completion", "decision_type": DecisionType.CHILD_WORKFLOW, "transition": StateTransition( next_state=DecisionState.COMPLETED, - action=lambda self, event: setattr(self, 'status', DecisionState.COMPLETED) - ) + action=lambda self, event: setattr(self, "status", DecisionState.COMPLETED), + ), }, "child_workflow_execution_canceled_event_attributes": { "type": "canceled", "decision_type": DecisionType.CHILD_WORKFLOW, "transition": StateTransition( next_state=DecisionState.CANCELED_AFTER_INITIATED - ) + ), }, "child_workflow_execution_terminated_event_attributes": { "type": "canceled", "decision_type": DecisionType.CHILD_WORKFLOW, "transition": StateTransition( next_state=DecisionState.CANCELED_AFTER_INITIATED - ) + ), }, "start_child_workflow_execution_failed_event_attributes": { "type": "initiation_failed", "decision_type": DecisionType.CHILD_WORKFLOW, "transition": StateTransition( next_state=DecisionState.COMPLETED, - action=lambda self, event: setattr(self, 'status', DecisionState.COMPLETED) - ) + action=lambda self, event: setattr(self, "status", DecisionState.COMPLETED), + ), }, } @@ -253,7 +256,7 @@ class BaseDecisionStateMachine: Subclasses are responsible for mapping workflow history events into state transitions and producing the next set of decisions when queried. 
""" - + # Common fields that subclasses may use scheduled_event_id: Optional[int] = None started_event_id: Optional[int] = None @@ -317,7 +320,7 @@ def _should_handle_event_by_event_id( # Check if the event ID matches our tracked event ID event_id = getattr(attr, event_id_field, None) tracked_event_id = getattr(self, self._get_event_id_field_name(), None) - + return event_id == tracked_event_id def _default_initiated_action(self, event: history.HistoryEvent) -> None: @@ -332,7 +335,9 @@ def _default_started_action(self, event: history.HistoryEvent) -> None: if hasattr(self, "started_event_id"): self.started_event_id = event.event_id - def _default_completion_action(self, event: history.HistoryEvent, attr_name: str) -> None: + def _default_completion_action( + self, event: history.HistoryEvent, attr_name: str + ) -> None: """Default action for completion events.""" self.status = DecisionState.COMPLETED @@ -357,7 +362,7 @@ def _default_cancel_failed_action(self, event: history.HistoryEvent) -> None: def handle_event(self, event: history.HistoryEvent, event_type: str) -> None: """Generic event handler that uses the global transition map to determine state changes. - + Args: event: The history event to process event_type: The type of event (e.g., 'initiated', 'started', 'completion', etc.) @@ -381,7 +386,7 @@ def _handle_initiated_event(self, event: history.HistoryEvent) -> None: """Handle initiated events using the global transition map.""" attr_name = self._get_initiated_event_attr_name() id_field = self._get_id_field_name() - + if not self._should_handle_event(event, attr_name, id_field): return @@ -410,8 +415,10 @@ def _handle_started_event(self, event: history.HistoryEvent) -> None: event_id_field = "initiated_event_id" else: event_id_field = self._get_event_id_field_name() - - if not self._should_handle_event_by_event_id(event, attr_name, event_id_field): + + if not self._should_handle_event_by_event_id( + event, attr_name, event_id_field + ): return transition_info = decision_state_transition_map.get(attr_name) @@ -433,18 +440,28 @@ def _handle_completion_event(self, event: history.HistoryEvent) -> None: if attr_name == "timer_fired_event_attributes": # Timer completion events use started_event_id event_id_field = "started_event_id" - elif attr_name in ["activity_task_completed_event_attributes", "activity_task_failed_event_attributes", "activity_task_timed_out_event_attributes"]: + elif attr_name in [ + "activity_task_completed_event_attributes", + "activity_task_failed_event_attributes", + "activity_task_timed_out_event_attributes", + ]: # Activity completion events use scheduled_event_id event_id_field = "scheduled_event_id" - elif attr_name in ["child_workflow_execution_completed_event_attributes", "child_workflow_execution_failed_event_attributes", "child_workflow_execution_timed_out_event_attributes"]: + elif attr_name in [ + "child_workflow_execution_completed_event_attributes", + "child_workflow_execution_failed_event_attributes", + "child_workflow_execution_timed_out_event_attributes", + ]: # Child workflow completion events use initiated_event_id event_id_field = "initiated_event_id" else: # Default case event_id_field = self._get_event_id_field_name() - + # Check if this event should be handled by this machine - if self._should_handle_event_by_event_id(event, attr_name, event_id_field): + if self._should_handle_event_by_event_id( + event, attr_name, event_id_field + ): transition_info = decision_state_transition_map.get(attr_name) if transition_info and 
transition_info["type"] == "completion": transition = transition_info["transition"] @@ -504,15 +521,20 @@ def _handle_canceled_event(self, event: history.HistoryEvent) -> None: elif attr_name == "activity_task_canceled_event_attributes": # Activity canceled events use scheduled_event_id event_id_field = "scheduled_event_id" - elif attr_name in ["child_workflow_execution_canceled_event_attributes", "child_workflow_execution_terminated_event_attributes"]: + elif attr_name in [ + "child_workflow_execution_canceled_event_attributes", + "child_workflow_execution_terminated_event_attributes", + ]: # Child workflow canceled events use initiated_event_id event_id_field = "initiated_event_id" else: # Default case event_id_field = self._get_event_id_field_name() - + # Check if this event should be handled by this machine - if self._should_handle_event_by_event_id(event, attr_name, event_id_field): + if self._should_handle_event_by_event_id( + event, attr_name, event_id_field + ): transition_info = decision_state_transition_map.get(attr_name) if transition_info and transition_info["type"] == "canceled": transition = transition_info["transition"] @@ -539,6 +561,7 @@ def collect_pending_decisions(self) -> List[decision.Decision]: # Activity + @dataclass class ActivityDecisionMachine(BaseDecisionStateMachine): """Tracks lifecycle of a single activity execution by activity_id.""" diff --git a/cadence/_internal/rpc/error.py b/cadence/_internal/rpc/error.py index d2dbd14..c8cf3ea 100644 --- a/cadence/_internal/rpc/error.py +++ b/cadence/_internal/rpc/error.py @@ -1,9 +1,15 @@ from typing import Callable, Any, Optional, Generator, TypeVar import grpc -from google.rpc.status_pb2 import Status # type: ignore -from grpc.aio import UnaryUnaryClientInterceptor, ClientCallDetails, AioRpcError, UnaryUnaryCall, Metadata -from grpc_status.rpc_status import from_call # type: ignore +from google.rpc.status_pb2 import Status # type: ignore +from grpc.aio import ( + UnaryUnaryClientInterceptor, + ClientCallDetails, + AioRpcError, + UnaryUnaryCall, + Metadata, +) +from grpc_status.rpc_status import from_call # type: ignore from cadence.api.v1 import error_pb2 from cadence import error @@ -19,14 +25,13 @@ # It doesn't have any functions to compose operations together, so our only option is to wrap it. 
# If the interceptor directly throws an exception other than AioRpcError it breaks GRPC class CadenceErrorUnaryUnaryCall(UnaryUnaryCall[RequestType, ResponseType]): - def __init__(self, wrapped: UnaryUnaryCall[RequestType, ResponseType]): super().__init__() self._wrapped = wrapped def __await__(self) -> Generator[Any, None, ResponseType]: try: - response = yield from self._wrapped.__await__() # type: ResponseType + response = yield from self._wrapped.__await__() # type: ResponseType return response except AioRpcError as e: raise map_error(e) @@ -47,35 +52,32 @@ async def wait_for_connection(self) -> None: await self._wrapped.wait_for_connection() def cancelled(self) -> bool: - return self._wrapped.cancelled() # type: ignore + return self._wrapped.cancelled() # type: ignore def done(self) -> bool: - return self._wrapped.done() # type: ignore + return self._wrapped.done() # type: ignore def time_remaining(self) -> Optional[float]: - return self._wrapped.time_remaining() # type: ignore + return self._wrapped.time_remaining() # type: ignore def cancel(self) -> bool: - return self._wrapped.cancel() # type: ignore + return self._wrapped.cancel() # type: ignore def add_done_callback(self, callback: DoneCallbackType) -> None: self._wrapped.add_done_callback(callback) class CadenceErrorInterceptor(UnaryUnaryClientInterceptor): - async def intercept_unary_unary( self, continuation: Callable[[ClientCallDetails, Any], Any], client_call_details: ClientCallDetails, - request: Any + request: Any, ) -> Any: rpc_call = await continuation(client_call_details, request) return CadenceErrorUnaryUnaryCall(rpc_call) - - def map_error(e: AioRpcError) -> error.CadenceError: status: Status | None = from_call(e) if not status or not status.details: @@ -85,25 +87,51 @@ def map_error(e: AioRpcError) -> error.CadenceError: if details.Is(error_pb2.WorkflowExecutionAlreadyStartedError.DESCRIPTOR): already_started = error_pb2.WorkflowExecutionAlreadyStartedError() details.Unpack(already_started) - return error.WorkflowExecutionAlreadyStartedError(e.details(), e.code(), already_started.start_request_id, already_started.run_id) + return error.WorkflowExecutionAlreadyStartedError( + e.details(), + e.code(), + already_started.start_request_id, + already_started.run_id, + ) elif details.Is(error_pb2.EntityNotExistsError.DESCRIPTOR): not_exists = error_pb2.EntityNotExistsError() details.Unpack(not_exists) - return error.EntityNotExistsError(e.details(), e.code(), not_exists.current_cluster, not_exists.active_cluster, list(not_exists.active_clusters)) + return error.EntityNotExistsError( + e.details(), + e.code(), + not_exists.current_cluster, + not_exists.active_cluster, + list(not_exists.active_clusters), + ) elif details.Is(error_pb2.WorkflowExecutionAlreadyCompletedError.DESCRIPTOR): return error.WorkflowExecutionAlreadyCompletedError(e.details(), e.code()) elif details.Is(error_pb2.DomainNotActiveError.DESCRIPTOR): not_active = error_pb2.DomainNotActiveError() details.Unpack(not_active) - return error.DomainNotActiveError(e.details(), e.code(), not_active.domain, not_active.current_cluster, not_active.active_cluster, list(not_active.active_clusters)) + return error.DomainNotActiveError( + e.details(), + e.code(), + not_active.domain, + not_active.current_cluster, + not_active.active_cluster, + list(not_active.active_clusters), + ) elif details.Is(error_pb2.ClientVersionNotSupportedError.DESCRIPTOR): not_supported = error_pb2.ClientVersionNotSupportedError() details.Unpack(not_supported) - return 
error.ClientVersionNotSupportedError(e.details(), e.code(), not_supported.feature_version, not_supported.client_impl, not_supported.supported_versions) + return error.ClientVersionNotSupportedError( + e.details(), + e.code(), + not_supported.feature_version, + not_supported.client_impl, + not_supported.supported_versions, + ) elif details.Is(error_pb2.FeatureNotEnabledError.DESCRIPTOR): not_enabled = error_pb2.FeatureNotEnabledError() details.Unpack(not_enabled) - return error.FeatureNotEnabledError(e.details(), e.code(), not_enabled.feature_flag) + return error.FeatureNotEnabledError( + e.details(), e.code(), not_enabled.feature_flag + ) elif details.Is(error_pb2.CancellationAlreadyRequestedError.DESCRIPTOR): return error.CancellationAlreadyRequestedError(e.details(), e.code()) elif details.Is(error_pb2.DomainAlreadyExistsError.DESCRIPTOR): @@ -118,4 +146,3 @@ def map_error(e: AioRpcError) -> error.CadenceError: return error.ServiceBusyError(e.details(), e.code(), service_busy.reason) else: return error.CadenceError(e.details(), e.code()) - diff --git a/cadence/_internal/rpc/retry.py b/cadence/_internal/rpc/retry.py index 7e1f280..dd3fd35 100644 --- a/cadence/_internal/rpc/retry.py +++ b/cadence/_internal/rpc/retry.py @@ -11,9 +11,10 @@ StatusCode.INTERNAL, StatusCode.RESOURCE_EXHAUSTED, StatusCode.ABORTED, - StatusCode.UNAVAILABLE + StatusCode.UNAVAILABLE, } + # No expiration interval, use the GRPC timeout value instead @dataclass class ExponentialRetryPolicy: @@ -22,20 +23,29 @@ class ExponentialRetryPolicy: max_interval: float max_attempts: float - def next_delay(self, attempts: int, elapsed: float, expiration: float) -> float | None: + def next_delay( + self, attempts: int, elapsed: float, expiration: float + ) -> float | None: if elapsed >= expiration: return None if self.max_attempts != 0 and attempts >= self.max_attempts: return None - backoff = min(self.initial_interval * pow(self.backoff_coefficient, attempts-1), self.max_interval) + backoff = min( + self.initial_interval * pow(self.backoff_coefficient, attempts - 1), + self.max_interval, + ) if (elapsed + backoff) >= expiration: return None return backoff -DEFAULT_RETRY_POLICY = ExponentialRetryPolicy(initial_interval=0.02, backoff_coefficient=1.2, max_interval=6, max_attempts=0) -GET_WORKFLOW_HISTORY = b'/uber.cadence.api.v1.WorkflowAPI/GetWorkflowExecutionHistory' + +DEFAULT_RETRY_POLICY = ExponentialRetryPolicy( + initial_interval=0.02, backoff_coefficient=1.2, max_interval=6, max_attempts=0 +) +GET_WORKFLOW_HISTORY = b"/uber.cadence.api.v1.WorkflowAPI/GetWorkflowExecutionHistory" + class RetryInterceptor(UnaryUnaryClientInterceptor): def __init__(self, retry_policy: ExponentialRetryPolicy = DEFAULT_RETRY_POLICY): @@ -46,7 +56,7 @@ async def intercept_unary_unary( self, continuation: Callable[[ClientCallDetails, Any], Any], client_call_details: ClientCallDetails, - request: Any + request: Any, ) -> Any: loop = asyncio.get_running_loop() expiration_interval = client_call_details.timeout @@ -68,7 +78,9 @@ async def intercept_unary_unary( attempts += 1 elapsed = loop.time() - start_time - backoff = self._retry_policy.next_delay(attempts, elapsed, expiration_interval) + backoff = self._retry_policy.next_delay( + attempts, elapsed, expiration_interval + ) if not is_retryable(err, client_call_details) or backoff is None: break @@ -78,10 +90,15 @@ async def intercept_unary_unary( return rpc_call - def is_retryable(err: CadenceError, call_details: ClientCallDetails) -> bool: # Handle requests to the passive side, matching the Go and 
Java Clients - if call_details.method == GET_WORKFLOW_HISTORY and isinstance(err, EntityNotExistsError): - return err.active_cluster is not None and err.current_cluster is not None and err.active_cluster != err.current_cluster + if call_details.method == GET_WORKFLOW_HISTORY and isinstance( + err, EntityNotExistsError + ): + return ( + err.active_cluster is not None + and err.current_cluster is not None + and err.active_cluster != err.current_cluster + ) return err.code in RETRYABLE_CODES diff --git a/cadence/_internal/rpc/yarpc.py b/cadence/_internal/rpc/yarpc.py index 42f7994..927b7ef 100644 --- a/cadence/_internal/rpc/yarpc.py +++ b/cadence/_internal/rpc/yarpc.py @@ -9,6 +9,7 @@ ENCODING_KEY = "rpc-encoding" ENCODING_PROTO = "proto" + class YarpcMetadataInterceptor(UnaryUnaryClientInterceptor): def __init__(self, service: str, caller: str): self._metadata = Metadata( @@ -18,15 +19,16 @@ def __init__(self, service: str, caller: str): ) async def intercept_unary_unary( - self, - continuation: Callable[[ClientCallDetails, Any], Any], - client_call_details: ClientCallDetails, - request: Any + self, + continuation: Callable[[ClientCallDetails, Any], Any], + client_call_details: ClientCallDetails, + request: Any, ) -> Any: return await continuation(self._replace_details(client_call_details), request) - - def _replace_details(self, client_call_details: ClientCallDetails) -> ClientCallDetails: + def _replace_details( + self, client_call_details: ClientCallDetails + ) -> ClientCallDetails: metadata = client_call_details.metadata if metadata is None: metadata = self._metadata @@ -35,4 +37,6 @@ def _replace_details(self, client_call_details: ClientCallDetails) -> ClientCall # Namedtuple methods start with an underscore to avoid conflicts and aren't actually private # noinspection PyProtectedMember - return client_call_details._replace(metadata=metadata, timeout=client_call_details.timeout or 60.0) + return client_call_details._replace( + metadata=metadata, timeout=client_call_details.timeout or 60.0 + ) diff --git a/cadence/_internal/workflow/context.py b/cadence/_internal/workflow/context.py index 87184d8..e2516ab 100644 --- a/cadence/_internal/workflow/context.py +++ b/cadence/_internal/workflow/context.py @@ -4,7 +4,6 @@ class Context(WorkflowContext): - def __init__(self, client: Client, info: WorkflowInfo): self._client = client self._info = info diff --git a/cadence/_internal/workflow/decision_events_iterator.py b/cadence/_internal/workflow/decision_events_iterator.py index cb0020b..b92b8ca 100644 --- a/cadence/_internal/workflow/decision_events_iterator.py +++ b/cadence/_internal/workflow/decision_events_iterator.py @@ -20,40 +20,41 @@ class DecisionEvents: """ Represents events for a single decision iteration. 
""" + events: List[HistoryEvent] = field(default_factory=list) markers: List[HistoryEvent] = field(default_factory=list) replay: bool = False replay_current_time_milliseconds: Optional[int] = None next_decision_event_id: Optional[int] = None - + def get_events(self) -> List[HistoryEvent]: """Return all events in this decision iteration.""" return self.events - - + def get_markers(self) -> List[HistoryEvent]: """Return marker events.""" return self.markers - + def is_replay(self) -> bool: """Check if this decision is in replay mode.""" return self.replay - + def get_event_by_id(self, event_id: int) -> Optional[HistoryEvent]: """Retrieve a specific event by ID, returns None if not found.""" for event in self.events: - if hasattr(event, 'event_id') and event.event_id == event_id: + if hasattr(event, "event_id") and event.event_id == event_id: return event return None + class DecisionEventsIterator: """ Iterator for processing decision events from workflow history. - + This is the main class that processes workflow history events and groups them into decision iterations for proper workflow replay and execution. """ - + def __init__(self, decision_task: PollForDecisionTaskResponse, client: Client): self._client = client self._decision_task = decision_task @@ -64,44 +65,51 @@ def __init__(self, decision_task: PollForDecisionTaskResponse, client: Client): self._replay = True self._replay_current_time_milliseconds: Optional[int] = None self._initialized = False - + @staticmethod def _is_decision_task_started(event: HistoryEvent) -> bool: """Check if event is DecisionTaskStarted.""" - return (hasattr(event, 'decision_task_started_event_attributes') and - event.HasField('decision_task_started_event_attributes')) - + return hasattr( + event, "decision_task_started_event_attributes" + ) and event.HasField("decision_task_started_event_attributes") + @staticmethod def _is_decision_task_completed(event: HistoryEvent) -> bool: """Check if event is DecisionTaskCompleted.""" - return (hasattr(event, 'decision_task_completed_event_attributes') and - event.HasField('decision_task_completed_event_attributes')) - + return hasattr( + event, "decision_task_completed_event_attributes" + ) and event.HasField("decision_task_completed_event_attributes") + @staticmethod def _is_decision_task_failed(event: HistoryEvent) -> bool: """Check if event is DecisionTaskFailed.""" - return (hasattr(event, 'decision_task_failed_event_attributes') and - event.HasField('decision_task_failed_event_attributes')) - + return hasattr( + event, "decision_task_failed_event_attributes" + ) and event.HasField("decision_task_failed_event_attributes") + @staticmethod def _is_decision_task_timed_out(event: HistoryEvent) -> bool: """Check if event is DecisionTaskTimedOut.""" - return (hasattr(event, 'decision_task_timed_out_event_attributes') and - event.HasField('decision_task_timed_out_event_attributes')) - + return hasattr( + event, "decision_task_timed_out_event_attributes" + ) and event.HasField("decision_task_timed_out_event_attributes") + @staticmethod def _is_marker_recorded(event: HistoryEvent) -> bool: """Check if event is MarkerRecorded.""" - return (hasattr(event, 'marker_recorded_event_attributes') and - event.HasField('marker_recorded_event_attributes')) - + return hasattr(event, "marker_recorded_event_attributes") and event.HasField( + "marker_recorded_event_attributes" + ) + @staticmethod def _is_decision_task_completion(event: HistoryEvent) -> bool: """Check if event is any kind of decision task completion.""" - return 
(DecisionEventsIterator._is_decision_task_completed(event) or - DecisionEventsIterator._is_decision_task_failed(event) or - DecisionEventsIterator._is_decision_task_timed_out(event)) - + return ( + DecisionEventsIterator._is_decision_task_completed(event) + or DecisionEventsIterator._is_decision_task_failed(event) + or DecisionEventsIterator._is_decision_task_timed_out(event) + ) + async def _ensure_initialized(self): """Initialize events list using the existing iterate_history_events.""" if not self._initialized: @@ -109,66 +117,70 @@ async def _ensure_initialized(self): events_iterator = iterate_history_events(self._decision_task, self._client) self._events = [event async for event in events_iterator] self._initialized = True - + # Find first decision task started event for i, event in enumerate(self._events): if self._is_decision_task_started(event): self._event_index = i break - + async def has_next_decision_events(self) -> bool: """Check if there are more decision events to process.""" await self._ensure_initialized() - + # Look for the next DecisionTaskStarted event from current position for i in range(self._event_index, len(self._events)): if self._is_decision_task_started(self._events[i]): return True - + return False - + async def next_decision_events(self) -> DecisionEvents: """ Get the next set of decision events. - + This method processes events starting from a DecisionTaskStarted event until the corresponding DecisionTaskCompleted/Failed/TimedOut event. """ await self._ensure_initialized() - + # Find next DecisionTaskStarted event start_index = None for i in range(self._event_index, len(self._events)): if self._is_decision_task_started(self._events[i]): start_index = i break - + if start_index is None: raise StopIteration("No more decision events") - + decision_events = DecisionEvents() decision_events.replay = self._replay - decision_events.replay_current_time_milliseconds = self._replay_current_time_milliseconds + decision_events.replay_current_time_milliseconds = ( + self._replay_current_time_milliseconds + ) decision_events.next_decision_event_id = self._next_decision_event_id - + # Process DecisionTaskStarted event decision_task_started = self._events[start_index] self._decision_task_started_event = decision_task_started decision_events.events.append(decision_task_started) - + # Update replay time if available if decision_task_started.event_time: self._replay_current_time_milliseconds = ( decision_task_started.event_time.seconds * 1000 ) - decision_events.replay_current_time_milliseconds = self._replay_current_time_milliseconds - + decision_events.replay_current_time_milliseconds = ( + self._replay_current_time_milliseconds + ) + # Process subsequent events until we find the corresponding DecisionTask completion current_index = start_index + 1 while current_index < len(self._events): event = self._events[current_index] decision_events.events.append(event) - + # Categorize the event if self._is_marker_recorded(event): decision_events.markers.append(event) @@ -177,18 +189,18 @@ async def next_decision_events(self) -> DecisionEvents: self._process_decision_completion_event(event, decision_events) current_index += 1 # Move past this event break - + current_index += 1 - + # Update the event index for next iteration self._event_index = current_index - + # Update the next decision event ID if decision_events.events: last_event = decision_events.events[-1] - if hasattr(last_event, 'event_id'): + if hasattr(last_event, "event_id"): self._next_decision_event_id = 
last_event.event_id + 1 - + # Check if this is the last decision events # Set replay to false only if there are no more decision events after this one # Check directly without calling has_next_decision_events to avoid recursion @@ -197,58 +209,63 @@ async def next_decision_events(self) -> DecisionEvents: if self._is_decision_task_started(self._events[i]): has_more = True break - + if not has_more: self._replay = False decision_events.replay = False - + return decision_events - - def _process_decision_completion_event(self, event: HistoryEvent, decision_events: DecisionEvents): + + def _process_decision_completion_event( + self, event: HistoryEvent, decision_events: DecisionEvents + ): """Process the decision completion event and update state.""" - + # Check if we're still in replay mode # This is determined by comparing event IDs with the current decision task's started event ID - if (self._decision_task_started_event and - hasattr(self._decision_task_started_event, 'event_id') and - hasattr(event, 'event_id')): - + if ( + self._decision_task_started_event + and hasattr(self._decision_task_started_event, "event_id") + and hasattr(event, "event_id") + ): # If this completion event ID is >= the current decision task's started event ID, # we're no longer in replay mode - current_task_started_id = getattr( - self._decision_task.started_event_id, 'value', 0 - ) if hasattr(self._decision_task, 'started_event_id') else 0 - + current_task_started_id = ( + getattr(self._decision_task.started_event_id, "value", 0) + if hasattr(self._decision_task, "started_event_id") + else 0 + ) + if event.event_id >= current_task_started_id: self._replay = False decision_events.replay = False - + def get_replay_current_time_milliseconds(self) -> Optional[int]: """Get the current replay time in milliseconds.""" return self._replay_current_time_milliseconds - + def is_replay_mode(self) -> bool: """Check if the iterator is currently in replay mode.""" return self._replay - + def __aiter__(self): return self - + async def __anext__(self) -> DecisionEvents: if not await self.has_next_decision_events(): raise StopAsyncIteration return await self.next_decision_events() - - # Utility functions def is_decision_event(event: HistoryEvent) -> bool: """Check if an event is a decision-related event.""" - return (DecisionEventsIterator._is_decision_task_started(event) or - DecisionEventsIterator._is_decision_task_completed(event) or - DecisionEventsIterator._is_decision_task_failed(event) or - DecisionEventsIterator._is_decision_task_timed_out(event)) + return ( + DecisionEventsIterator._is_decision_task_started(event) + or DecisionEventsIterator._is_decision_task_completed(event) + or DecisionEventsIterator._is_decision_task_failed(event) + or DecisionEventsIterator._is_decision_task_timed_out(event) + ) def is_marker_event(event: HistoryEvent) -> bool: @@ -258,7 +275,7 @@ def is_marker_event(event: HistoryEvent) -> bool: def extract_event_timestamp_millis(event: HistoryEvent) -> Optional[int]: """Extract timestamp from an event in milliseconds.""" - if hasattr(event, 'event_time') and event.HasField('event_time'): - seconds = getattr(event.event_time, 'seconds', 0) + if hasattr(event, "event_time") and event.HasField("event_time"): + seconds = getattr(event.event_time, "seconds", 0) return seconds * 1000 if seconds > 0 else None - return None \ No newline at end of file + return None diff --git a/cadence/_internal/workflow/decisions_helper.py b/cadence/_internal/workflow/decisions_helper.py index 4099150..d92fb73 100644 --- 
a/cadence/_internal/workflow/decisions_helper.py +++ b/cadence/_internal/workflow/decisions_helper.py @@ -9,7 +9,11 @@ from dataclasses import dataclass from typing import Dict, Optional -from cadence._internal.decision_state_machine import DecisionId, DecisionType, DecisionManager +from cadence._internal.decision_state_machine import ( + DecisionId, + DecisionType, + DecisionManager, +) logger = logging.getLogger(__name__) @@ -232,7 +236,6 @@ def update_decision_completed(self, decision_key: str) -> None: else: logger.warning(f"No tracker found for decision key: {decision_key}") - def _find_decision_by_scheduled_event_id( self, scheduled_event_id: int ) -> Optional[str]: diff --git a/cadence/_internal/workflow/deterministic_event_loop.py b/cadence/_internal/workflow/deterministic_event_loop.py index f0af8e3..8cc5dca 100644 --- a/cadence/_internal/workflow/deterministic_event_loop.py +++ b/cadence/_internal/workflow/deterministic_event_loop.py @@ -11,6 +11,7 @@ _Ts = TypeVarTuple("_Ts") + class DeterministicEventLoop(AbstractEventLoop): """ This is a basic FIFO implementation of event loop that does not allow I/O or timer operations. @@ -20,13 +21,18 @@ class DeterministicEventLoop(AbstractEventLoop): """ def __init__(self): - self._thread_id = None # indicate if the event loop is running + self._thread_id = None # indicate if the event loop is running self._debug = False - self._ready = collections.deque[events.Handle]() + self._ready = collections.deque[events.Handle]() self._stopping = False self._closed = False - def call_soon(self, callback: Callable[[Unpack[_Ts]], object], *args: Unpack[_Ts], context: Context | None = None) -> Handle: + def call_soon( + self, + callback: Callable[[Unpack[_Ts]], object], + *args: Unpack[_Ts], + context: Context | None = None, + ) -> Handle: return self._call_soon(callback, args, context) def _call_soon(self, callback, args, context) -> Handle: @@ -81,7 +87,7 @@ def run_until_complete(self, future): finally: future.remove_done_callback(_run_until_complete_cb) if not future.done(): - raise RuntimeError('Event loop stopped before Future completed.') + raise RuntimeError("Event loop stopped before Future completed.") return future.result() @@ -94,7 +100,9 @@ def create_task(self, coro, **kwargs): # NOTE: eager_start is not supported for deterministic event loop if kwargs.get("eager_start", False): - raise RuntimeError("eager_start in create_task is not supported for deterministic event loop") + raise RuntimeError( + "eager_start in create_task is not supported for deterministic event loop" + ) return tasks.Task(coro, loop=self, **kwargs) @@ -125,17 +133,18 @@ def stop(self): def _check_closed(self): if self._closed: - raise RuntimeError('Event loop is closed') + raise RuntimeError("Event loop is closed") def _check_running(self): if self.is_running(): - raise RuntimeError('This event loop is already running') + raise RuntimeError("This event loop is already running") if events._get_running_loop() is not None: raise RuntimeError( - 'Cannot run the event loop while another loop is running') + "Cannot run the event loop while another loop is running" + ) def is_running(self): - return (self._thread_id is not None) + return self._thread_id is not None def close(self): """Close the event loop. 
@@ -154,6 +163,7 @@ def is_closed(self): """Returns True if the event loop was closed.""" return self._closed + def _run_until_complete_cb(fut): if not fut.cancelled(): exc = fut.exception() diff --git a/cadence/_internal/workflow/history_event_iterator.py b/cadence/_internal/workflow/history_event_iterator.py index 3d99497..900f8a5 100644 --- a/cadence/_internal/workflow/history_event_iterator.py +++ b/cadence/_internal/workflow/history_event_iterator.py @@ -1,9 +1,14 @@ - from cadence.api.v1.service_worker_pb2 import PollForDecisionTaskResponse -from cadence.api.v1.service_workflow_pb2 import GetWorkflowExecutionHistoryRequest, GetWorkflowExecutionHistoryResponse +from cadence.api.v1.service_workflow_pb2 import ( + GetWorkflowExecutionHistoryRequest, + GetWorkflowExecutionHistoryResponse, +) from cadence.client import Client -async def iterate_history_events(decision_task: PollForDecisionTaskResponse, client: Client): + +async def iterate_history_events( + decision_task: PollForDecisionTaskResponse, client: Client +): PAGE_SIZE = 1000 current_page = decision_task.history.events @@ -15,11 +20,15 @@ async def iterate_history_events(decision_task: PollForDecisionTaskResponse, cli yield event if not next_page_token: break - response: GetWorkflowExecutionHistoryResponse = await client.workflow_stub.GetWorkflowExecutionHistory(GetWorkflowExecutionHistoryRequest( - domain=client.domain, - workflow_execution=workflow_execution, - next_page_token=next_page_token, - page_size=PAGE_SIZE, - )) + response: GetWorkflowExecutionHistoryResponse = ( + await client.workflow_stub.GetWorkflowExecutionHistory( + GetWorkflowExecutionHistoryRequest( + domain=client.domain, + workflow_execution=workflow_execution, + next_page_token=next_page_token, + page_size=PAGE_SIZE, + ) + ) + ) current_page = response.history.events next_page_token = response.next_page_token diff --git a/cadence/_internal/workflow/workflow_engine.py b/cadence/_internal/workflow/workflow_engine.py index 67606b7..627434b 100644 --- a/cadence/_internal/workflow/workflow_engine.py +++ b/cadence/_internal/workflow/workflow_engine.py @@ -19,6 +19,7 @@ class DecisionResult: decisions: list[Decision] + class WorkflowEngine: def __init__(self, info: WorkflowInfo, client: Client, workflow_definition=None): self._context = Context(client, info) @@ -30,7 +31,9 @@ def __init__(self, info: WorkflowInfo, client: Client, workflow_definition=None) self._decisions_helper = DecisionsHelper(self._decision_manager) self._is_workflow_complete = False - async def process_decision(self, decision_task: PollForDecisionTaskResponse) -> DecisionResult: + async def process_decision( + self, decision_task: PollForDecisionTaskResponse + ) -> DecisionResult: """ Process a decision task and generate decisions using DecisionEventsIterator. 
@@ -52,14 +55,16 @@ async def process_decision(self, decision_task: PollForDecisionTaskResponse) -> "workflow_id": self._context.info().workflow_id, "run_id": self._context.info().workflow_run_id, "started_event_id": decision_task.started_event_id, - "attempt": decision_task.attempt - } + "attempt": decision_task.attempt, + }, ) # Activate workflow context for the entire decision processing with self._context._activate(): # Create DecisionEventsIterator for structured event processing - events_iterator = DecisionEventsIterator(decision_task, self._context.client()) + events_iterator = DecisionEventsIterator( + decision_task, self._context.client() + ) # Process decision events using iterator-driven approach await self._process_decision_events(events_iterator, decision_task) @@ -79,8 +84,8 @@ async def process_decision(self, decision_task: PollForDecisionTaskResponse) -> "run_id": self._context.info().workflow_run_id, "started_event_id": decision_task.started_event_id, "decisions_count": len(decisions), - "replay_mode": self._context.is_replay_mode() - } + "replay_mode": self._context.is_replay_mode(), + }, ) return DecisionResult(decisions=decisions) @@ -95,14 +100,18 @@ async def process_decision(self, decision_task: PollForDecisionTaskResponse) -> "run_id": self._context.info().workflow_run_id, "started_event_id": decision_task.started_event_id, "attempt": decision_task.attempt, - "error_type": type(e).__name__ + "error_type": type(e).__name__, }, - exc_info=True + exc_info=True, ) # Re-raise the exception so the handler can properly handle the failure raise - - async def _process_decision_events(self, events_iterator: DecisionEventsIterator, decision_task: PollForDecisionTaskResponse) -> None: + + async def _process_decision_events( + self, + events_iterator: DecisionEventsIterator, + decision_task: PollForDecisionTaskResponse, + ) -> None: """ Process decision events using the iterator-driven approach similar to Java client. 
@@ -132,14 +141,16 @@ async def _process_decision_events(self, events_iterator: DecisionEventsIterator "events_count": len(decision_events.get_events()), "markers_count": len(decision_events.get_markers()), "replay_mode": decision_events.is_replay(), - "replay_time": decision_events.replay_current_time_milliseconds - } + "replay_time": decision_events.replay_current_time_milliseconds, + }, ) # Update context with replay information self._context.set_replay_mode(decision_events.is_replay()) if decision_events.replay_current_time_milliseconds: - self._context.set_replay_current_time_milliseconds(decision_events.replay_current_time_milliseconds) + self._context.set_replay_current_time_milliseconds( + decision_events.replay_current_time_milliseconds + ) # Phase 1: Process markers first for deterministic replay for marker_event in decision_events.get_markers(): @@ -148,10 +159,12 @@ async def _process_decision_events(self, events_iterator: DecisionEventsIterator "Processing marker event", extra={ "workflow_id": self._context.info().workflow_id, - "marker_name": getattr(marker_event, 'marker_name', 'unknown'), - "event_id": getattr(marker_event, 'event_id', None), - "replay_mode": self._context.is_replay_mode() - } + "marker_name": getattr( + marker_event, "marker_name", "unknown" + ), + "event_id": getattr(marker_event, "event_id", None), + "replay_mode": self._context.is_replay_mode(), + }, ) # Process through state machines (DecisionsHelper now delegates to DecisionManager) self._decision_manager.handle_history_event(marker_event) @@ -161,11 +174,13 @@ async def _process_decision_events(self, events_iterator: DecisionEventsIterator "Unexpected marker event encountered", extra={ "workflow_id": self._context.info().workflow_id, - "marker_name": getattr(marker_event, 'marker_name', 'unknown'), - "event_id": getattr(marker_event, 'event_id', None), - "error_type": type(e).__name__ + "marker_name": getattr( + marker_event, "marker_name", "unknown" + ), + "event_id": getattr(marker_event, "event_id", None), + "error_type": type(e).__name__, }, - exc_info=True + exc_info=True, ) # Phase 2: Process regular events to update workflow state @@ -175,10 +190,10 @@ async def _process_decision_events(self, events_iterator: DecisionEventsIterator "Processing history event", extra={ "workflow_id": self._context.info().workflow_id, - "event_type": getattr(event, 'event_type', 'unknown'), - "event_id": getattr(event, 'event_id', None), - "replay_mode": self._context.is_replay_mode() - } + "event_type": getattr(event, "event_type", "unknown"), + "event_id": getattr(event, "event_id", None), + "replay_mode": self._context.is_replay_mode(), + }, ) # Process through state machines (DecisionsHelper now delegates to DecisionManager) self._decision_manager.handle_history_event(event) @@ -187,11 +202,11 @@ async def _process_decision_events(self, events_iterator: DecisionEventsIterator "Error processing history event", extra={ "workflow_id": self._context.info().workflow_id, - "event_type": getattr(event, 'event_type', 'unknown'), - "event_id": getattr(event, 'event_id', None), - "error_type": type(e).__name__ + "event_type": getattr(event, "event_type", "unknown"), + "event_id": getattr(event, "event_id", None), + "error_type": type(e).__name__, }, - exc_info=True + exc_info=True, ) # Phase 3: Execute workflow logic if not in replay mode @@ -200,19 +215,24 @@ async def _process_decision_events(self, events_iterator: DecisionEventsIterator # If no decision events were processed but we have history, fall back to direct 
processing # This handles edge cases where the iterator doesn't find decision events - if not processed_any_decision_events and decision_task.history and hasattr(decision_task.history, 'events'): + if ( + not processed_any_decision_events + and decision_task.history + and hasattr(decision_task.history, "events") + ): logger.debug( "No decision events found by iterator, falling back to direct history processing", extra={ "workflow_id": self._context.info().workflow_id, - "history_events_count": len(decision_task.history.events) if decision_task.history else 0 - } + "history_events_count": len(decision_task.history.events) + if decision_task.history + else 0, + }, ) self._fallback_process_workflow_history(decision_task.history) if not self._is_workflow_complete: await self._execute_workflow_function(decision_task) - def _fallback_process_workflow_history(self, history) -> None: """ Fallback method to process workflow history events directly. @@ -223,15 +243,15 @@ def _fallback_process_workflow_history(self, history) -> None: Args: history: The workflow history from the decision task """ - if not history or not hasattr(history, 'events'): + if not history or not hasattr(history, "events"): return logger.debug( "Processing history events in fallback mode", extra={ "workflow_id": self._context.info().workflow_id, - "events_count": len(history.events) - } + "events_count": len(history.events), + }, ) for event in history.events: @@ -243,14 +263,16 @@ def _fallback_process_workflow_history(self, history) -> None: "Error processing history event in fallback mode", extra={ "workflow_id": self._context.info().workflow_id, - "event_type": getattr(event, 'event_type', 'unknown'), - "event_id": getattr(event, 'event_id', None), - "error_type": type(e).__name__ + "event_type": getattr(event, "event_type", "unknown"), + "event_id": getattr(event, "event_id", None), + "error_type": type(e).__name__, }, - exc_info=True + exc_info=True, ) - - async def _execute_workflow_function(self, decision_task: PollForDecisionTaskResponse) -> None: + + async def _execute_workflow_function( + self, decision_task: PollForDecisionTaskResponse + ) -> None: """ Execute the workflow function to generate new decisions. 
@@ -267,19 +289,23 @@ async def _execute_workflow_function(self, decision_task: PollForDecisionTaskRes extra={ "workflow_type": self._context.info().workflow_type, "workflow_id": self._context.info().workflow_id, - "run_id": self._context.info().workflow_run_id - } + "run_id": self._context.info().workflow_run_id, + }, ) return # Get the workflow run method from the instance - workflow_func = self._workflow_definition.get_run_method(self._workflow_instance) + workflow_func = self._workflow_definition.get_run_method( + self._workflow_instance + ) # Extract workflow input from history workflow_input = await self._extract_workflow_input(decision_task) # Execute workflow function - result = await self._execute_workflow_function_once(workflow_func, workflow_input) + result = await self._execute_workflow_function_once( + workflow_func, workflow_input + ) # Check if workflow is complete if result is not None: @@ -291,8 +317,8 @@ async def _execute_workflow_function(self, decision_task: PollForDecisionTaskRes "workflow_type": self._context.info().workflow_type, "workflow_id": self._context.info().workflow_id, "run_id": self._context.info().workflow_run_id, - "completion_type": "success" - } + "completion_type": "success", + }, ) except Exception as e: @@ -302,46 +328,54 @@ async def _execute_workflow_function(self, decision_task: PollForDecisionTaskRes "workflow_type": self._context.info().workflow_type, "workflow_id": self._context.info().workflow_id, "run_id": self._context.info().workflow_run_id, - "error_type": type(e).__name__ + "error_type": type(e).__name__, }, - exc_info=True + exc_info=True, ) raise - - async def _extract_workflow_input(self, decision_task: PollForDecisionTaskResponse) -> Any: + + async def _extract_workflow_input( + self, decision_task: PollForDecisionTaskResponse + ) -> Any: """ Extract workflow input from the decision task history. 
- + Args: decision_task: The decision task containing workflow history - + Returns: The workflow input data, or None if not found """ - if not decision_task.history or not hasattr(decision_task.history, 'events'): + if not decision_task.history or not hasattr(decision_task.history, "events"): logger.warning("No history events found in decision task") return None - + # Look for WorkflowExecutionStarted event for event in decision_task.history.events: - if hasattr(event, 'workflow_execution_started_event_attributes'): + if hasattr(event, "workflow_execution_started_event_attributes"): started_attrs = event.workflow_execution_started_event_attributes - if started_attrs and hasattr(started_attrs, 'input'): + if started_attrs and hasattr(started_attrs, "input"): # Deserialize the input using the client's data converter try: # Use from_data method with a single type hint of None (no type conversion) - input_data_list = await self._context.client().data_converter.from_data(started_attrs.input, [None]) + input_data_list = ( + await self._context.client().data_converter.from_data( + started_attrs.input, [None] + ) + ) input_data = input_data_list[0] if input_data_list else None logger.debug(f"Extracted workflow input: {input_data}") return input_data except Exception as e: logger.warning(f"Failed to deserialize workflow input: {e}") return None - + logger.warning("No WorkflowExecutionStarted event found in history") return None - - async def _execute_workflow_function_once(self, workflow_func: Callable, workflow_input: Any) -> Any: + + async def _execute_workflow_function_once( + self, workflow_func: Callable, workflow_input: Any + ) -> Any: """ Execute the workflow function once (not during replay). @@ -354,13 +388,13 @@ async def _execute_workflow_function_once(self, workflow_func: Callable, workflo """ logger.debug(f"Executing workflow function with input: {workflow_input}") result = workflow_func(workflow_input) - + # If the workflow function is async, await it properly if asyncio.iscoroutine(result): result = await result - + return result - + def _close_event_loop(self) -> None: """ Close the decider's event loop. diff --git a/cadence/activity.py b/cadence/activity.py index 57a9b48..581bea9 100644 --- a/cadence/activity.py +++ b/cadence/activity.py @@ -7,8 +7,19 @@ from enum import Enum from functools import update_wrapper from inspect import signature, Parameter -from typing import Iterator, TypedDict, Unpack, Callable, Type, ParamSpec, TypeVar, Generic, get_type_hints, \ - Any, overload +from typing import ( + Iterator, + TypedDict, + Unpack, + Callable, + Type, + ParamSpec, + TypeVar, + Generic, + get_type_hints, + Any, + overload, +) from cadence import Client @@ -29,27 +40,27 @@ class ActivityInfo: start_to_close_timeout: timedelta attempt: int + def client() -> Client: return ActivityContext.get().client() + def in_activity() -> bool: return ActivityContext.is_set() + def info() -> ActivityInfo: return ActivityContext.get().info() - class ActivityContext(ABC): - _var: ContextVar['ActivityContext'] = ContextVar("activity") + _var: ContextVar["ActivityContext"] = ContextVar("activity") @abstractmethod - def info(self) -> ActivityInfo: - ... + def info(self) -> ActivityInfo: ... @abstractmethod - def client(self) -> Client: - ... + def client(self) -> Client: ... 
@contextmanager def _activate(self) -> Iterator[None]: @@ -62,7 +73,7 @@ def is_set() -> bool: return ActivityContext._var.get(None) is not None @staticmethod - def get() -> 'ActivityContext': + def get() -> "ActivityContext": return ActivityContext._var.get() @@ -72,18 +83,28 @@ class ActivityParameter: type_hint: Type | None default_value: Any | None + class ExecutionStrategy(Enum): ASYNC = "async" THREAD_POOL = "thread_pool" + class ActivityDefinitionOptions(TypedDict, total=False): name: str -P = ParamSpec('P') -T = TypeVar('T') + +P = ParamSpec("P") +T = TypeVar("T") + class ActivityDefinition(Generic[P, T]): - def __init__(self, wrapped: Callable[P, T], name: str, strategy: ExecutionStrategy, params: list[ActivityParameter]): + def __init__( + self, + wrapped: Callable[P, T], + name: str, + strategy: ExecutionStrategy, + params: list[ActivityParameter], + ): self._wrapped = wrapped self._name = name self._strategy = strategy @@ -106,13 +127,15 @@ def params(self) -> list[ActivityParameter]: return self._params @staticmethod - def wrap(fn: Callable[P, T], opts: ActivityDefinitionOptions) -> 'ActivityDefinition[P, T]': + def wrap( + fn: Callable[P, T], opts: ActivityDefinitionOptions + ) -> "ActivityDefinition[P, T]": name = fn.__qualname__ if "name" in opts and opts["name"]: name = opts["name"] strategy = ExecutionStrategy.THREAD_POOL - if inspect.iscoroutinefunction(fn) or inspect.iscoroutinefunction(fn.__call__): # type: ignore + if inspect.iscoroutinefunction(fn) or inspect.iscoroutinefunction(fn.__call__): # type: ignore strategy = ExecutionStrategy.ASYNC params = _get_params(fn) @@ -121,16 +144,20 @@ def wrap(fn: Callable[P, T], opts: ActivityDefinitionOptions) -> 'ActivityDefini ActivityDecorator = Callable[[Callable[P, T]], ActivityDefinition[P, T]] + @overload -def defn(fn: Callable[P, T]) -> ActivityDefinition[P, T]: - ... +def defn(fn: Callable[P, T]) -> ActivityDefinition[P, T]: ... + @overload -def defn(**kwargs: Unpack[ActivityDefinitionOptions]) -> ActivityDecorator: - ... +def defn(**kwargs: Unpack[ActivityDefinitionOptions]) -> ActivityDecorator: ... + -def defn(fn: Callable[P, T] | None = None, **kwargs: Unpack[ActivityDefinitionOptions]) -> ActivityDecorator | ActivityDefinition[P, T]: +def defn( + fn: Callable[P, T] | None = None, **kwargs: Unpack[ActivityDefinitionOptions] +) -> ActivityDecorator | ActivityDefinition[P, T]: options = ActivityDefinitionOptions(**kwargs) + def decorator(inner_fn: Callable[P, T]) -> ActivityDefinition[P, T]: return ActivityDefinition.wrap(inner_fn, options) @@ -157,6 +184,8 @@ def _get_params(fn: Callable) -> list[ActivityParameter]: result.append(ActivityParameter(name, type_hint, default)) else: - raise ValueError(f"Parameters must be positional. {name} is {param.kind}, and not valid") + raise ValueError( + f"Parameters must be positional. 
{name} is {param.kind}, and not valid" + ) return result diff --git a/cadence/data_converter.py b/cadence/data_converter.py index 350848e..16851bf 100644 --- a/cadence/data_converter.py +++ b/cadence/data_converter.py @@ -5,26 +5,30 @@ from json import JSONDecoder from msgspec import json, convert -_SPACE = ' '.encode() +_SPACE = " ".encode() -class DataConverter(Protocol): +class DataConverter(Protocol): @abstractmethod - async def from_data(self, payload: Payload, type_hints: List[Type | None]) -> List[Any]: + async def from_data( + self, payload: Payload, type_hints: List[Type | None] + ) -> List[Any]: raise NotImplementedError() @abstractmethod async def to_data(self, values: List[Any]) -> Payload: raise NotImplementedError() + class DefaultDataConverter(DataConverter): def __init__(self) -> None: self._encoder = json.Encoder() # Need to use std lib decoder in order to decode the custom whitespace delimited data format self._decoder = JSONDecoder(strict=False) - - async def from_data(self, payload: Payload, type_hints: List[Type | None]) -> List[Any]: + async def from_data( + self, payload: Payload, type_hints: List[Type | None] + ) -> List[Any]: if not payload.data: return DefaultDataConverter._convert_into([], type_hints) @@ -32,7 +36,9 @@ async def from_data(self, payload: Payload, type_hints: List[Type | None]) -> Li return self._decode_whitespace_delimited(payload_str, type_hints) - def _decode_whitespace_delimited(self, payload: str, type_hints: List[Type | None]) -> List[Any]: + def _decode_whitespace_delimited( + self, payload: str, type_hints: List[Type | None] + ) -> List[Any]: results: List[Any] = [] start, end = 0, len(payload) while start < end and len(results) < len(type_hints): @@ -66,7 +72,6 @@ def _get_default(type_hint: Type) -> Any: return False return None - async def to_data(self, values: List[Any]) -> Payload: result = bytearray() for index, value in enumerate(values): @@ -75,4 +80,3 @@ async def to_data(self, values: List[Any]) -> Payload: result += _SPACE return Payload(data=bytes(result)) - diff --git a/cadence/error.py b/cadence/error.py index a7ea5fd..1d3be6d 100644 --- a/cadence/error.py +++ b/cadence/error.py @@ -2,63 +2,101 @@ class CadenceError(Exception): - def __init__(self, message: str, code: grpc.StatusCode, *args): super().__init__(message, code, *args) self.code = code + pass class WorkflowExecutionAlreadyStartedError(CadenceError): - - def __init__(self, message: str, code: grpc.StatusCode, start_request_id: str, run_id: str) -> None: + def __init__( + self, message: str, code: grpc.StatusCode, start_request_id: str, run_id: str + ) -> None: super().__init__(message, code, start_request_id, run_id) self.start_request_id = start_request_id self.run_id = run_id -class EntityNotExistsError(CadenceError): - def __init__(self, message: str, code: grpc.StatusCode, current_cluster: str, active_cluster: str, active_clusters: list[str]) -> None: - super().__init__(message, code, current_cluster, active_cluster, active_clusters) +class EntityNotExistsError(CadenceError): + def __init__( + self, + message: str, + code: grpc.StatusCode, + current_cluster: str, + active_cluster: str, + active_clusters: list[str], + ) -> None: + super().__init__( + message, code, current_cluster, active_cluster, active_clusters + ) self.current_cluster = current_cluster self.active_cluster = active_cluster self.active_clusters = active_clusters + class WorkflowExecutionAlreadyCompletedError(CadenceError): pass + class DomainNotActiveError(CadenceError): - def __init__(self, 
message: str, code: grpc.StatusCode, domain: str, current_cluster: str, active_cluster: str, active_clusters: list[str]) -> None: - super().__init__(message, code, domain, current_cluster, active_cluster, active_clusters) + def __init__( + self, + message: str, + code: grpc.StatusCode, + domain: str, + current_cluster: str, + active_cluster: str, + active_clusters: list[str], + ) -> None: + super().__init__( + message, code, domain, current_cluster, active_cluster, active_clusters + ) self.domain = domain self.current_cluster = current_cluster self.active_cluster = active_cluster self.active_clusters = active_clusters + class ClientVersionNotSupportedError(CadenceError): - def __init__(self, message: str, code: grpc.StatusCode, feature_version: str, client_impl: str, supported_versions: str) -> None: - super().__init__(message, code, feature_version, client_impl, supported_versions) + def __init__( + self, + message: str, + code: grpc.StatusCode, + feature_version: str, + client_impl: str, + supported_versions: str, + ) -> None: + super().__init__( + message, code, feature_version, client_impl, supported_versions + ) self.feature_version = feature_version self.client_impl = client_impl self.supported_versions = supported_versions + class FeatureNotEnabledError(CadenceError): def __init__(self, message: str, code: grpc.StatusCode, feature_flag: str) -> None: super().__init__(message, code, feature_flag) self.feature_flag = feature_flag + class CancellationAlreadyRequestedError(CadenceError): pass + class DomainAlreadyExistsError(CadenceError): pass + class LimitExceededError(CadenceError): pass + class QueryFailedError(CadenceError): pass + class ServiceBusyError(CadenceError): def __init__(self, message: str, code: grpc.StatusCode, reason: str) -> None: super().__init__(message, code, reason) diff --git a/cadence/metrics/__init__.py b/cadence/metrics/__init__.py index a933fea..792b56a 100644 --- a/cadence/metrics/__init__.py +++ b/cadence/metrics/__init__.py @@ -7,6 +7,6 @@ "MetricsEmitter", "NoOpMetricsEmitter", "MetricType", - "PrometheusMetrics", + "PrometheusMetrics", "PrometheusConfig", ] diff --git a/cadence/metrics/constants.py b/cadence/metrics/constants.py index 9b04e31..ee69f4c 100644 --- a/cadence/metrics/constants.py +++ b/cadence/metrics/constants.py @@ -12,38 +12,56 @@ WORKFLOW_CONTINUE_AS_NEW_COUNTER = CADENCE_METRICS_PREFIX + "workflow-continue-as-new" WORKFLOW_END_TO_END_LATENCY = CADENCE_METRICS_PREFIX + "workflow-endtoend-latency" WORKFLOW_GET_HISTORY_COUNTER = CADENCE_METRICS_PREFIX + "workflow-get-history-total" -WORKFLOW_GET_HISTORY_FAILED_COUNTER = CADENCE_METRICS_PREFIX + "workflow-get-history-failed" -WORKFLOW_GET_HISTORY_SUCCEED_COUNTER = CADENCE_METRICS_PREFIX + "workflow-get-history-succeed" +WORKFLOW_GET_HISTORY_FAILED_COUNTER = ( + CADENCE_METRICS_PREFIX + "workflow-get-history-failed" +) +WORKFLOW_GET_HISTORY_SUCCEED_COUNTER = ( + CADENCE_METRICS_PREFIX + "workflow-get-history-succeed" +) WORKFLOW_GET_HISTORY_LATENCY = CADENCE_METRICS_PREFIX + "workflow-get-history-latency" -WORKFLOW_SIGNAL_WITH_START_COUNTER = CADENCE_METRICS_PREFIX + "workflow-signal-with-start" -WORKFLOW_SIGNAL_WITH_START_ASYNC_COUNTER = CADENCE_METRICS_PREFIX + "workflow-signal-with-start-async" +WORKFLOW_SIGNAL_WITH_START_COUNTER = ( + CADENCE_METRICS_PREFIX + "workflow-signal-with-start" +) +WORKFLOW_SIGNAL_WITH_START_ASYNC_COUNTER = ( + CADENCE_METRICS_PREFIX + "workflow-signal-with-start-async" +) DECISION_TIMEOUT_COUNTER = CADENCE_METRICS_PREFIX + "decision-timeout" # Decision Poll 
metrics DECISION_POLL_COUNTER = CADENCE_METRICS_PREFIX + "decision-poll-total" DECISION_POLL_FAILED_COUNTER = CADENCE_METRICS_PREFIX + "decision-poll-failed" -DECISION_POLL_TRANSIENT_FAILED_COUNTER = CADENCE_METRICS_PREFIX + "decision-poll-transient-failed" +DECISION_POLL_TRANSIENT_FAILED_COUNTER = ( + CADENCE_METRICS_PREFIX + "decision-poll-transient-failed" +) DECISION_POLL_NO_TASK_COUNTER = CADENCE_METRICS_PREFIX + "decision-poll-no-task" DECISION_POLL_SUCCEED_COUNTER = CADENCE_METRICS_PREFIX + "decision-poll-succeed" DECISION_POLL_LATENCY = CADENCE_METRICS_PREFIX + "decision-poll-latency" DECISION_POLL_INVALID_COUNTER = CADENCE_METRICS_PREFIX + "decision-poll-invalid" -DECISION_SCHEDULED_TO_START_LATENCY = CADENCE_METRICS_PREFIX + "decision-scheduled-to-start-latency" +DECISION_SCHEDULED_TO_START_LATENCY = ( + CADENCE_METRICS_PREFIX + "decision-scheduled-to-start-latency" +) DECISION_EXECUTION_FAILED_COUNTER = CADENCE_METRICS_PREFIX + "decision-execution-failed" DECISION_EXECUTION_LATENCY = CADENCE_METRICS_PREFIX + "decision-execution-latency" DECISION_RESPONSE_FAILED_COUNTER = CADENCE_METRICS_PREFIX + "decision-response-failed" DECISION_RESPONSE_LATENCY = CADENCE_METRICS_PREFIX + "decision-response-latency" DECISION_TASK_PANIC_COUNTER = CADENCE_METRICS_PREFIX + "decision-task-panic" DECISION_TASK_COMPLETED_COUNTER = CADENCE_METRICS_PREFIX + "decision-task-completed" -DECISION_TASK_FORCE_COMPLETED_COUNTER = CADENCE_METRICS_PREFIX + "decision-task-force-completed" +DECISION_TASK_FORCE_COMPLETED_COUNTER = ( + CADENCE_METRICS_PREFIX + "decision-task-force-completed" +) # Activity Poll metrics ACTIVITY_POLL_COUNTER = CADENCE_METRICS_PREFIX + "activity-poll-total" ACTIVITY_POLL_FAILED_COUNTER = CADENCE_METRICS_PREFIX + "activity-poll-failed" -ACTIVITY_POLL_TRANSIENT_FAILED_COUNTER = CADENCE_METRICS_PREFIX + "activity-poll-transient-failed" +ACTIVITY_POLL_TRANSIENT_FAILED_COUNTER = ( + CADENCE_METRICS_PREFIX + "activity-poll-transient-failed" +) ACTIVITY_POLL_NO_TASK_COUNTER = CADENCE_METRICS_PREFIX + "activity-poll-no-task" ACTIVITY_POLL_SUCCEED_COUNTER = CADENCE_METRICS_PREFIX + "activity-poll-succeed" ACTIVITY_POLL_LATENCY = CADENCE_METRICS_PREFIX + "activity-poll-latency" -ACTIVITY_SCHEDULED_TO_START_LATENCY = CADENCE_METRICS_PREFIX + "activity-scheduled-to-start-latency" +ACTIVITY_SCHEDULED_TO_START_LATENCY = ( + CADENCE_METRICS_PREFIX + "activity-scheduled-to-start-latency" +) ACTIVITY_EXECUTION_FAILED_COUNTER = CADENCE_METRICS_PREFIX + "activity-execution-failed" ACTIVITY_EXECUTION_LATENCY = CADENCE_METRICS_PREFIX + "activity-execution-latency" ACTIVITY_RESPONSE_LATENCY = CADENCE_METRICS_PREFIX + "activity-response-latency" @@ -53,9 +71,15 @@ ACTIVITY_TASK_COMPLETED_COUNTER = CADENCE_METRICS_PREFIX + "activity-task-completed" ACTIVITY_TASK_FAILED_COUNTER = CADENCE_METRICS_PREFIX + "activity-task-failed" ACTIVITY_TASK_CANCELED_COUNTER = CADENCE_METRICS_PREFIX + "activity-task-canceled" -ACTIVITY_TASK_COMPLETED_BY_ID_COUNTER = CADENCE_METRICS_PREFIX + "activity-task-completed-by-id" -ACTIVITY_TASK_FAILED_BY_ID_COUNTER = CADENCE_METRICS_PREFIX + "activity-task-failed-by-id" -ACTIVITY_TASK_CANCELED_BY_ID_COUNTER = CADENCE_METRICS_PREFIX + "activity-task-canceled-by-id" +ACTIVITY_TASK_COMPLETED_BY_ID_COUNTER = ( + CADENCE_METRICS_PREFIX + "activity-task-completed-by-id" +) +ACTIVITY_TASK_FAILED_BY_ID_COUNTER = ( + CADENCE_METRICS_PREFIX + "activity-task-failed-by-id" +) +ACTIVITY_TASK_CANCELED_BY_ID_COUNTER = ( + CADENCE_METRICS_PREFIX + "activity-task-canceled-by-id" +) # Local Activity 
metrics LOCAL_ACTIVITY_TOTAL_COUNTER = CADENCE_METRICS_PREFIX + "local-activity-total" @@ -63,12 +87,24 @@ LOCAL_ACTIVITY_CANCELED_COUNTER = CADENCE_METRICS_PREFIX + "local-activity-canceled" LOCAL_ACTIVITY_FAILED_COUNTER = CADENCE_METRICS_PREFIX + "local-activity-failed" LOCAL_ACTIVITY_PANIC_COUNTER = CADENCE_METRICS_PREFIX + "local-activity-panic" -LOCAL_ACTIVITY_EXECUTION_LATENCY = CADENCE_METRICS_PREFIX + "local-activity-execution-latency" -LOCALLY_DISPATCHED_ACTIVITY_POLL_COUNTER = CADENCE_METRICS_PREFIX + "locally-dispatched-activity-poll-total" -LOCALLY_DISPATCHED_ACTIVITY_POLL_NO_TASK_COUNTER = CADENCE_METRICS_PREFIX + "locally-dispatched-activity-poll-no-task" -LOCALLY_DISPATCHED_ACTIVITY_POLL_SUCCEED_COUNTER = CADENCE_METRICS_PREFIX + "locally-dispatched-activity-poll-succeed" -ACTIVITY_LOCAL_DISPATCH_FAILED_COUNTER = CADENCE_METRICS_PREFIX + "activity-local-dispatch-failed" -ACTIVITY_LOCAL_DISPATCH_SUCCEED_COUNTER = CADENCE_METRICS_PREFIX + "activity-local-dispatch-succeed" +LOCAL_ACTIVITY_EXECUTION_LATENCY = ( + CADENCE_METRICS_PREFIX + "local-activity-execution-latency" +) +LOCALLY_DISPATCHED_ACTIVITY_POLL_COUNTER = ( + CADENCE_METRICS_PREFIX + "locally-dispatched-activity-poll-total" +) +LOCALLY_DISPATCHED_ACTIVITY_POLL_NO_TASK_COUNTER = ( + CADENCE_METRICS_PREFIX + "locally-dispatched-activity-poll-no-task" +) +LOCALLY_DISPATCHED_ACTIVITY_POLL_SUCCEED_COUNTER = ( + CADENCE_METRICS_PREFIX + "locally-dispatched-activity-poll-succeed" +) +ACTIVITY_LOCAL_DISPATCH_FAILED_COUNTER = ( + CADENCE_METRICS_PREFIX + "activity-local-dispatch-failed" +) +ACTIVITY_LOCAL_DISPATCH_SUCCEED_COUNTER = ( + CADENCE_METRICS_PREFIX + "activity-local-dispatch-succeed" +) WORKER_PANIC_COUNTER = CADENCE_METRICS_PREFIX + "worker-panic" # Signal metrics @@ -98,4 +134,3 @@ REPLAY_FAILED_COUNTER = CADENCE_METRICS_PREFIX + "replay-failed" REPLAY_SKIPPED_COUNTER = CADENCE_METRICS_PREFIX + "replay-skipped" REPLAY_LATENCY = CADENCE_METRICS_PREFIX + "replay-latency" - diff --git a/cadence/metrics/metrics.py b/cadence/metrics/metrics.py index d2ca3d2..e5e8425 100644 --- a/cadence/metrics/metrics.py +++ b/cadence/metrics/metrics.py @@ -30,7 +30,6 @@ def gauge( """Send a gauge metric.""" ... 
- def histogram( self, key: str, value: float, tags: Optional[Dict[str, str]] = None ) -> None: @@ -51,10 +50,7 @@ def gauge( ) -> None: pass - def histogram( self, key: str, value: float, tags: Optional[Dict[str, str]] = None ) -> None: pass - - diff --git a/cadence/metrics/prometheus.py b/cadence/metrics/prometheus.py index 277c863..486a96a 100644 --- a/cadence/metrics/prometheus.py +++ b/cadence/metrics/prometheus.py @@ -107,7 +107,6 @@ def _get_or_create_histogram( return self._histograms[metric_name] - def counter( self, key: str, n: int = 1, tags: Optional[Dict[str, str]] = None ) -> None: @@ -140,7 +139,6 @@ def gauge( except Exception as e: logger.error(f"Failed to send gauge {key}: {e}") - def histogram( self, key: str, value: float, tags: Optional[Dict[str, str]] = None ) -> None: diff --git a/cadence/sample/__init__.py b/cadence/sample/__init__.py index b5fecce..76de7e6 100644 --- a/cadence/sample/__init__.py +++ b/cadence/sample/__init__.py @@ -1 +1 @@ -# Sample directory for cadence protobuf import tests \ No newline at end of file +# Sample directory for cadence protobuf import tests diff --git a/cadence/sample/client_example.py b/cadence/sample/client_example.py index 152f7a8..3ed111f 100644 --- a/cadence/sample/client_example.py +++ b/cadence/sample/client_example.py @@ -10,5 +10,6 @@ async def main(): worker = Worker(client, "task_list", Registry()) await worker.run() -if __name__ == '__main__': + +if __name__ == "__main__": asyncio.run(main()) diff --git a/cadence/sample/grpc_usage_example.py b/cadence/sample/grpc_usage_example.py index 5bd9aa3..0f0afd3 100644 --- a/cadence/sample/grpc_usage_example.py +++ b/cadence/sample/grpc_usage_example.py @@ -8,14 +8,16 @@ from cadence.api.v1 import service_workflow_grpc, service_workflow, common -def create_grpc_channel(server_address: str = "localhost:7833", use_ssl: bool = False) -> grpc.Channel: +def create_grpc_channel( + server_address: str = "localhost:7833", use_ssl: bool = False +) -> grpc.Channel: """ Create a gRPC channel to connect to Cadence server. - + Args: server_address: The address of the Cadence server (host:port) use_ssl: Whether to use SSL/TLS for the connection - + Returns: grpc.Channel: The gRPC channel """ @@ -28,23 +30,27 @@ def create_grpc_channel(server_address: str = "localhost:7833", use_ssl: bool = return grpc.insecure_channel(server_address) -def create_workflow_client(channel: grpc.Channel) -> service_workflow_grpc.WorkflowAPIStub: +def create_workflow_client( + channel: grpc.Channel, +) -> service_workflow_grpc.WorkflowAPIStub: """ Create a gRPC client for the WorkflowAPI service. - + Args: channel: The gRPC channel - + Returns: WorkflowAPIStub: The gRPC client stub """ return service_workflow_grpc.WorkflowAPIStub(channel) -def example_start_workflow(client: service_workflow_grpc.WorkflowAPIStub, domain: str, workflow_id: str): +def example_start_workflow( + client: service_workflow_grpc.WorkflowAPIStub, domain: str, workflow_id: str +): """ Example of starting a workflow execution using gRPC. 
- + Args: client: The gRPC client domain: The Cadence domain @@ -60,7 +66,7 @@ def example_start_workflow(client: service_workflow_grpc.WorkflowAPIStub, domain request.execution_start_to_close_timeout.seconds = 3600 # 1 hour request.task_start_to_close_timeout.seconds = 60 # 1 minute request.identity = "python-client" - + try: # Make the gRPC call response = client.StartWorkflowExecution(request) @@ -71,11 +77,15 @@ def example_start_workflow(client: service_workflow_grpc.WorkflowAPIStub, domain return None -def example_describe_workflow(client: service_workflow_grpc.WorkflowAPIStub, domain: str, workflow_id: str, - run_id: str): +def example_describe_workflow( + client: service_workflow_grpc.WorkflowAPIStub, + domain: str, + workflow_id: str, + run_id: str, +): """ Example of describing a workflow execution using gRPC. - + Args: client: The gRPC client domain: The Cadence domain @@ -89,7 +99,7 @@ def example_describe_workflow(client: service_workflow_grpc.WorkflowAPIStub, dom execution.workflow_id = workflow_id execution.run_id = run_id request.workflow_execution.CopyFrom(execution) - + try: # Make the gRPC call response = client.DescribeWorkflowExecution(request) @@ -100,10 +110,15 @@ def example_describe_workflow(client: service_workflow_grpc.WorkflowAPIStub, dom return None -def example_get_workflow_history(client: service_workflow_grpc.WorkflowAPIStub, domain: str, workflow_id: str, run_id: str): +def example_get_workflow_history( + client: service_workflow_grpc.WorkflowAPIStub, + domain: str, + workflow_id: str, + run_id: str, +): """ Example of getting workflow execution history using gRPC. - + Args: client: The gRPC client domain: The Cadence domain @@ -118,7 +133,7 @@ def example_get_workflow_history(client: service_workflow_grpc.WorkflowAPIStub, execution.run_id = run_id request.workflow_execution.CopyFrom(execution) request.page_size = 100 - + try: # Make the gRPC call response = client.GetWorkflowExecutionHistory(request) @@ -129,10 +144,16 @@ def example_get_workflow_history(client: service_workflow_grpc.WorkflowAPIStub, return None -def example_query_workflow(client: service_workflow_grpc.WorkflowAPIStub, domain: str, workflow_id: str, run_id: str, query_type: str): +def example_query_workflow( + client: service_workflow_grpc.WorkflowAPIStub, + domain: str, + workflow_id: str, + run_id: str, + query_type: str, +): """ Example of querying a workflow using gRPC. - + Args: client: The gRPC client domain: The Cadence domain @@ -149,7 +170,7 @@ def example_query_workflow(client: service_workflow_grpc.WorkflowAPIStub, domain request.workflow_execution.CopyFrom(execution) request.query.query_type = query_type request.query.query_args.data = b"query arguments" # Serialized query arguments - + try: # Make the gRPC call response = client.QueryWorkflow(request) @@ -164,46 +185,46 @@ def main(): """Main example function.""" print("Cadence gRPC Client Example") print("=" * 40) - + # Configuration server_address = "localhost:7833" # Default Cadence gRPC port domain = "test-domain" workflow_id = "example-workflow-123" run_id = "example-run-456" - + try: # Create gRPC channel print(f"Connecting to Cadence server at {server_address}...") channel = create_grpc_channel(server_address) - + # Create gRPC client client = create_workflow_client(channel) print("✓ gRPC client created successfully") - + # Example 1: Start a workflow print("\n1. Starting a workflow...") example_start_workflow(client, domain, workflow_id) - + # Example 2: Describe a workflow print("\n2. 
Describing a workflow...") example_describe_workflow(client, domain, workflow_id, run_id) - + # Example 3: Get workflow history print("\n3. Getting workflow history...") example_get_workflow_history(client, domain, workflow_id, run_id) - + # Example 4: Query a workflow print("\n4. Querying a workflow...") example_query_workflow(client, domain, workflow_id, run_id, "status") - + except Exception as e: print(f"✗ Error: {e}") finally: # Close the channel - if 'channel' in locals(): + if "channel" in locals(): channel.close() print("\n✓ gRPC channel closed") if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/cadence/sample/simple_usage_example.py b/cadence/sample/simple_usage_example.py index 6b3b373..2c94443 100644 --- a/cadence/sample/simple_usage_example.py +++ b/cadence/sample/simple_usage_example.py @@ -8,66 +8,68 @@ import os # Add the project root to the path so we can import cadence modules -project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +project_root = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +) sys.path.insert(0, project_root) def example_workflow_execution(): """Example of creating and using WorkflowExecution objects.""" print("=== Workflow Execution Example ===") - + from cadence.api.v1 import common, workflow - + # Create a workflow execution wf_exec = common.WorkflowExecution() wf_exec.workflow_id = "my-workflow-123" wf_exec.run_id = "run-456" - + print("Created workflow execution:") print(f" - Workflow ID: {wf_exec.workflow_id}") print(f" - Run ID: {wf_exec.run_id}") - + # Create workflow execution info wf_info = workflow.WorkflowExecutionInfo() wf_info.workflow_execution.CopyFrom(wf_exec) wf_info.type.name = "MyWorkflowType" wf_info.start_time.seconds = 1234567890 wf_info.close_time.seconds = 1234567990 - + print("Created workflow execution info:") print(f" - Type: {wf_info.type.name}") print(f" - Start Time: {wf_info.start_time.seconds}") print(f" - Close Time: {wf_info.close_time.seconds}") - + return wf_exec, wf_info def example_domain_operations(): """Example of creating and using Domain objects.""" print("\n=== Domain Operations Example ===") - + from cadence.api.v1 import domain - + # Create a domain domain_obj = domain.Domain() domain_obj.name = "my-domain" domain_obj.status = domain.DOMAIN_STATUS_REGISTERED domain_obj.description = "My test domain" - + print("Created domain:") print(f" - Name: {domain_obj.name}") print(f" - Status: {domain_obj.status}") print(f" - Description: {domain_obj.description}") - + return domain_obj def example_enum_usage(): """Example of using enum values.""" print("\n=== Enum Usage Example ===") - + from cadence.api.v1 import workflow - + # Workflow execution close status print("Workflow Execution Close Status:") print(f" - COMPLETED: {workflow.WORKFLOW_EXECUTION_CLOSE_STATUS_COMPLETED}") @@ -75,14 +77,14 @@ def example_enum_usage(): print(f" - CANCELED: {workflow.WORKFLOW_EXECUTION_CLOSE_STATUS_CANCELED}") print(f" - TERMINATED: {workflow.WORKFLOW_EXECUTION_CLOSE_STATUS_TERMINATED}") print(f" - TIMED_OUT: {workflow.WORKFLOW_EXECUTION_CLOSE_STATUS_TIMED_OUT}") - + # Timeout types print("\nTimeout Types:") print(f" - START_TO_CLOSE: {workflow.TIMEOUT_TYPE_START_TO_CLOSE}") print(f" - SCHEDULE_TO_CLOSE: {workflow.TIMEOUT_TYPE_SCHEDULE_TO_CLOSE}") print(f" - SCHEDULE_TO_START: {workflow.TIMEOUT_TYPE_SCHEDULE_TO_START}") print(f" - HEARTBEAT: {workflow.TIMEOUT_TYPE_HEARTBEAT}") - + # Parent close policies print("\nParent Close 
Policies:") print(f" - TERMINATE: {workflow.PARENT_CLOSE_POLICY_TERMINATE}") @@ -93,28 +95,31 @@ def example_enum_usage(): def example_serialization(): """Example of serializing and deserializing protobuf objects.""" print("\n=== Serialization Example ===") - + from cadence.api.v1 import common - + # Create a workflow execution wf_exec = common.WorkflowExecution() wf_exec.workflow_id = "serialization-test" wf_exec.run_id = "run-789" - + # Serialize to bytes serialized = wf_exec.SerializeToString() print(f"Serialized size: {len(serialized)} bytes") - + # Deserialize from bytes new_wf_exec = common.WorkflowExecution() new_wf_exec.ParseFromString(serialized) - + print("Deserialized workflow execution:") print(f" - Workflow ID: {new_wf_exec.workflow_id}") print(f" - Run ID: {new_wf_exec.run_id}") - + # Verify they're equal - if wf_exec.workflow_id == new_wf_exec.workflow_id and wf_exec.run_id == new_wf_exec.run_id: + if ( + wf_exec.workflow_id == new_wf_exec.workflow_id + and wf_exec.run_id == new_wf_exec.run_id + ): print("✓ Serialization/deserialization successful!") else: print("✗ Serialization/deserialization failed!") @@ -124,26 +129,27 @@ def main(): """Main example function.""" print("🚀 Cadence Protobuf Usage Examples") print("=" * 50) - + try: # Run all examples example_workflow_execution() example_domain_operations() example_enum_usage() example_serialization() - + print("\n" + "=" * 50) print("✅ All examples completed successfully!") print("The protobuf modules are working correctly and ready for use.") - + except Exception as e: print(f"\n❌ Example failed: {e}") import traceback + traceback.print_exc() return 1 - + return 0 if __name__ == "__main__": - exit(main()) \ No newline at end of file + exit(main()) diff --git a/cadence/worker/__init__.py b/cadence/worker/__init__.py index 4084e9a..41c315e 100644 --- a/cadence/worker/__init__.py +++ b/cadence/worker/__init__.py @@ -1,9 +1,4 @@ - - -from ._worker import ( - Worker, - WorkerOptions -) +from ._worker import Worker, WorkerOptions from ._registry import ( Registry, @@ -13,6 +8,6 @@ __all__ = [ "Worker", "WorkerOptions", - 'Registry', - 'RegisterWorkflowOptions', + "Registry", + "RegisterWorkflowOptions", ] diff --git a/cadence/worker/_activity.py b/cadence/worker/_activity.py index 2e7591f..85f9565 100644 --- a/cadence/worker/_activity.py +++ b/cadence/worker/_activity.py @@ -2,7 +2,10 @@ from typing import Optional from cadence._internal.activity import ActivityExecutor -from cadence.api.v1.service_worker_pb2 import PollForActivityTaskResponse, PollForActivityTaskRequest +from cadence.api.v1.service_worker_pb2 import ( + PollForActivityTaskResponse, + PollForActivityTaskRequest, +) from cadence.api.v1.tasklist_pb2 import TaskList, TaskListKind from cadence.client import Client from cadence.worker._registry import Registry @@ -11,25 +14,42 @@ class ActivityWorker: - def __init__(self, client: Client, task_list: str, registry: Registry, options: WorkerOptions) -> None: + def __init__( + self, client: Client, task_list: str, registry: Registry, options: WorkerOptions + ) -> None: self._client = client self._task_list = task_list self._identity = options["identity"] max_concurrent = options["max_concurrent_activity_execution_size"] permits = asyncio.Semaphore(max_concurrent) - self._executor = ActivityExecutor(self._client, self._task_list, options["identity"], max_concurrent, registry.get_activity) - self._poller = Poller[PollForActivityTaskResponse](options["activity_task_pollers"], permits, self._poll, self._execute) + 
self._executor = ActivityExecutor( + self._client, + self._task_list, + options["identity"], + max_concurrent, + registry.get_activity, + ) + self._poller = Poller[PollForActivityTaskResponse]( + options["activity_task_pollers"], permits, self._poll, self._execute + ) # TODO: Local dispatch, local activities, actually running activities, etc async def run(self) -> None: await self._poller.run() async def _poll(self) -> Optional[PollForActivityTaskResponse]: - task: PollForActivityTaskResponse = await self._client.worker_stub.PollForActivityTask(PollForActivityTaskRequest( - domain=self._client.domain, - task_list=TaskList(name=self._task_list,kind=TaskListKind.TASK_LIST_KIND_NORMAL), - identity=self._identity, - ), timeout=_LONG_POLL_TIMEOUT) + task: PollForActivityTaskResponse = ( + await self._client.worker_stub.PollForActivityTask( + PollForActivityTaskRequest( + domain=self._client.domain, + task_list=TaskList( + name=self._task_list, kind=TaskListKind.TASK_LIST_KIND_NORMAL + ), + identity=self._identity, + ), + timeout=_LONG_POLL_TIMEOUT, + ) + ) if task.task_token: return task @@ -38,4 +58,3 @@ async def _poll(self) -> Optional[PollForActivityTaskResponse]: async def _execute(self, task: PollForActivityTaskResponse) -> None: await self._executor.execute(task) - diff --git a/cadence/worker/_base_task_handler.py b/cadence/worker/_base_task_handler.py index 3fda7e7..042f80b 100644 --- a/cadence/worker/_base_task_handler.py +++ b/cadence/worker/_base_task_handler.py @@ -4,20 +4,21 @@ logger = logging.getLogger(__name__) -T = TypeVar('T') +T = TypeVar("T") + class BaseTaskHandler(ABC, Generic[T]): """ Base task handler that provides common functionality for processing tasks. - + This abstract class defines the interface and common behavior for task handlers that process different types of tasks (workflow decisions, activities, etc.). """ - + def __init__(self, client, task_list: str, identity: str, **options): """ Initialize the base task handler. - + Args: client: The Cadence client instance task_list: The task list name @@ -28,41 +29,41 @@ def __init__(self, client, task_list: str, identity: str, **options): self._task_list = task_list self._identity = identity self._options = options - + async def handle_task(self, task: T) -> None: """ Handle a single task. - + This method provides the base implementation for task handling that includes: - Error handling - Cleanup - + Args: task: The task to handle """ try: # Handle the task implementation await self._handle_task_implementation(task) - + except Exception as e: logger.exception(f"Error handling task: {e}") await self.handle_task_failure(task, e) - + @abstractmethod async def _handle_task_implementation(self, task: T) -> None: """ Handle the actual task implementation. - + Args: task: The task to handle """ pass - + @abstractmethod async def handle_task_failure(self, task: T, error: Exception) -> None: """ Handle task processing failure. 
- + Args: task: The task that failed error: The exception that occurred diff --git a/cadence/worker/_decision_task_handler.py b/cadence/worker/_decision_task_handler.py index 62f0edb..0f5f780 100644 --- a/cadence/worker/_decision_task_handler.py +++ b/cadence/worker/_decision_task_handler.py @@ -6,7 +6,7 @@ from cadence.api.v1.service_worker_pb2 import ( PollForDecisionTaskResponse, RespondDecisionTaskCompletedRequest, - RespondDecisionTaskFailedRequest + RespondDecisionTaskFailedRequest, ) from cadence.api.v1.workflow_pb2 import DecisionTaskFailedCause from cadence.client import Client @@ -17,18 +17,26 @@ logger = logging.getLogger(__name__) + class DecisionTaskHandler(BaseTaskHandler[PollForDecisionTaskResponse]): """ Task handler for processing decision tasks. - + This handler processes decision tasks and generates decisions using workflow engines. Uses a thread-safe cache to hold workflow engines for concurrent decision task handling. """ - - def __init__(self, client: Client, task_list: str, registry: Registry, identity: str = "unknown", **options): + + def __init__( + self, + client: Client, + task_list: str, + registry: Registry, + identity: str = "unknown", + **options, + ): """ Initialize the decision task handler. - + Args: client: The Cadence client instance task_list: The task list name @@ -41,27 +49,30 @@ def __init__(self, client: Client, task_list: str, registry: Registry, identity: # Thread-safe cache to hold workflow engines keyed by (workflow_id, run_id) self._workflow_engines: Dict[Tuple[str, str], WorkflowEngine] = {} self._cache_lock = threading.RLock() - - - async def _handle_task_implementation(self, task: PollForDecisionTaskResponse) -> None: + + async def _handle_task_implementation( + self, task: PollForDecisionTaskResponse + ) -> None: """ Handle a decision task implementation. - + Args: task: The decision task to handle """ # Extract workflow execution info workflow_execution = task.workflow_execution workflow_type = task.workflow_type - + if not workflow_execution or not workflow_type: - logger.error("Decision task missing workflow execution or type. Task: %r", task) + logger.error( + "Decision task missing workflow execution or type. 
Task: %r", task + ) raise ValueError("Missing workflow execution or type") - + workflow_id = workflow_execution.workflow_id run_id = workflow_execution.run_id workflow_type_name = workflow_type.name - + # This log matches the WorkflowEngine but at task handler level (like Java ReplayDecisionTaskHandler) logger.info( "Received decision task for workflow", @@ -71,10 +82,12 @@ async def _handle_task_implementation(self, task: PollForDecisionTaskResponse) - "run_id": run_id, "started_event_id": task.started_event_id, "attempt": task.attempt, - "task_token": task.task_token[:16].hex() if task.task_token else None # Log partial token for debugging - } + "task_token": task.task_token[:16].hex() + if task.task_token + else None, # Log partial token for debugging + }, ) - + try: workflow_definition = self._registry.get_workflow(workflow_type_name) except KeyError: @@ -84,19 +97,19 @@ async def _handle_task_implementation(self, task: PollForDecisionTaskResponse) - "workflow_type": workflow_type_name, "workflow_id": workflow_id, "run_id": run_id, - "error_type": "workflow_not_registered" - } + "error_type": "workflow_not_registered", + }, ) raise KeyError(f"Workflow type '{workflow_type_name}' not found") - + # Create workflow info and get or create workflow engine from cache workflow_info = WorkflowInfo( workflow_type=workflow_type_name, workflow_domain=self._client.domain, workflow_id=workflow_id, - workflow_run_id=run_id + workflow_run_id=run_id, ) - + # Use thread-safe cache to get or create workflow engine cache_key = (workflow_id, run_id) with self._cache_lock: @@ -105,12 +118,12 @@ async def _handle_task_implementation(self, task: PollForDecisionTaskResponse) - workflow_engine = WorkflowEngine( info=workflow_info, client=self._client, - workflow_definition=workflow_definition + workflow_definition=workflow_definition, ) self._workflow_engines[cache_key] = workflow_engine - + decision_result = await workflow_engine.process_decision(task) - + # Clean up completed workflows from cache to prevent memory leaks if workflow_engine._is_workflow_complete: with self._cache_lock: @@ -120,34 +133,38 @@ async def _handle_task_implementation(self, task: PollForDecisionTaskResponse) - extra={ "workflow_id": workflow_id, "run_id": run_id, - "cache_size": len(self._workflow_engines) - } + "cache_size": len(self._workflow_engines), + }, ) - + # Respond with the decisions await self._respond_decision_task_completed(task, decision_result) - + logger.info( "Successfully processed decision task", extra={ "workflow_type": workflow_type_name, "workflow_id": workflow_id, "run_id": run_id, - "started_event_id": task.started_event_id - } + "started_event_id": task.started_event_id, + }, ) - - async def handle_task_failure(self, task: PollForDecisionTaskResponse, error: Exception) -> None: + + async def handle_task_failure( + self, task: PollForDecisionTaskResponse, error: Exception + ) -> None: """ Handle decision task processing failure. 
- + Args: task: The task that failed error: The exception that occurred """ # Extract workflow context for error logging (matches Java ReplayDecisionTaskHandler error patterns) workflow_execution = task.workflow_execution - workflow_id = workflow_execution.workflow_id if workflow_execution else "unknown" + workflow_id = ( + workflow_execution.workflow_id if workflow_execution else "unknown" + ) run_id = workflow_execution.run_id if workflow_execution else "unknown" workflow_type = task.workflow_type.name if task.workflow_type else "unknown" @@ -161,23 +178,23 @@ async def handle_task_failure(self, task: PollForDecisionTaskResponse, error: Ex "started_event_id": task.started_event_id, "attempt": task.attempt, "error_type": type(error).__name__, - "error_message": str(error) + "error_message": str(error), }, - exc_info=True + exc_info=True, ) - + # Determine the failure cause cause = DecisionTaskFailedCause.DECISION_TASK_FAILED_CAUSE_UNHANDLED_DECISION if isinstance(error, KeyError): cause = DecisionTaskFailedCause.DECISION_TASK_FAILED_CAUSE_WORKFLOW_WORKER_UNHANDLED_FAILURE elif isinstance(error, ValueError): cause = DecisionTaskFailedCause.DECISION_TASK_FAILED_CAUSE_BAD_SCHEDULE_ACTIVITY_ATTRIBUTES - + # Create error details # TODO: Use a data converter for error details serialization - error_message = str(error).encode('utf-8') + error_message = str(error).encode("utf-8") details = Payload(data=error_message) - + # Respond with failure try: await self._client.worker_stub.RespondDecisionTaskFailed( @@ -185,7 +202,7 @@ async def handle_task_failure(self, task: PollForDecisionTaskResponse, error: Ex task_token=task.task_token, cause=cause, identity=self._identity, - details=details + details=details, ) ) logger.info( @@ -194,8 +211,10 @@ async def handle_task_failure(self, task: PollForDecisionTaskResponse, error: Ex "workflow_id": workflow_id, "run_id": run_id, "cause": cause, - "task_token": task.task_token[:16].hex() if task.task_token else None - } + "task_token": task.task_token[:16].hex() + if task.task_token + else None, + }, ) except Exception as e: logger.error( @@ -204,15 +223,17 @@ async def handle_task_failure(self, task: PollForDecisionTaskResponse, error: Ex "workflow_id": workflow_id, "run_id": run_id, "original_error": type(error).__name__, - "response_error": type(e).__name__ + "response_error": type(e).__name__, }, - exc_info=True + exc_info=True, ) - - async def _respond_decision_task_completed(self, task: PollForDecisionTaskResponse, decision_result: DecisionResult) -> None: + + async def _respond_decision_task_completed( + self, task: PollForDecisionTaskResponse, decision_result: DecisionResult + ) -> None: """ Respond to the service that the decision task has been completed. 
- + Args: task: The original decision task decision_result: The result containing decisions and query results @@ -222,9 +243,9 @@ async def _respond_decision_task_completed(self, task: PollForDecisionTaskRespon task_token=task.task_token, decisions=decision_result.decisions, identity=self._identity, - return_new_decision_task=True + return_new_decision_task=True, ) - + await self._client.worker_stub.RespondDecisionTaskCompleted(request) # Log completion response (matches Java ReplayDecisionTaskHandler trace/debug patterns) @@ -232,28 +253,42 @@ async def _respond_decision_task_completed(self, task: PollForDecisionTaskRespon logger.debug( "Decision task completion response sent", extra={ - "workflow_type": task.workflow_type.name if task.workflow_type else "unknown", - "workflow_id": workflow_execution.workflow_id if workflow_execution else "unknown", - "run_id": workflow_execution.run_id if workflow_execution else "unknown", + "workflow_type": task.workflow_type.name + if task.workflow_type + else "unknown", + "workflow_id": workflow_execution.workflow_id + if workflow_execution + else "unknown", + "run_id": workflow_execution.run_id + if workflow_execution + else "unknown", "started_event_id": task.started_event_id, "decisions_count": len(decision_result.decisions), "return_new_decision_task": True, - "task_token": task.task_token[:16].hex() if task.task_token else None - } + "task_token": task.task_token[:16].hex() + if task.task_token + else None, + }, ) - + except Exception as e: workflow_execution = task.workflow_execution logger.error( "Error responding to decision task completion", extra={ - "workflow_type": task.workflow_type.name if task.workflow_type else "unknown", - "workflow_id": workflow_execution.workflow_id if workflow_execution else "unknown", - "run_id": workflow_execution.run_id if workflow_execution else "unknown", + "workflow_type": task.workflow_type.name + if task.workflow_type + else "unknown", + "workflow_id": workflow_execution.workflow_id + if workflow_execution + else "unknown", + "run_id": workflow_execution.run_id + if workflow_execution + else "unknown", "started_event_id": task.started_event_id, "decisions_count": len(decision_result.decisions), - "error_type": type(e).__name__ + "error_type": type(e).__name__, }, - exc_info=True + exc_info=True, ) raise diff --git a/cadence/worker/_poller.py b/cadence/worker/_poller.py index a185d27..4c259dc 100644 --- a/cadence/worker/_poller.py +++ b/cadence/worker/_poller.py @@ -4,10 +4,17 @@ logger = logging.getLogger(__name__) -T = TypeVar('T') +T = TypeVar("T") + class Poller(Generic[T]): - def __init__(self, num_tasks: int, permits: asyncio.Semaphore, poll: Callable[[], Awaitable[Optional[T]]], callback: Callable[[T], Awaitable[None]]) -> None: + def __init__( + self, + num_tasks: int, + permits: asyncio.Semaphore, + poll: Callable[[], Awaitable[Optional[T]]], + callback: Callable[[T], Awaitable[None]], + ) -> None: self._num_tasks = num_tasks self._permits = permits self._poll = poll @@ -20,8 +27,7 @@ async def run(self) -> None: for i in range(self._num_tasks): tg.create_task(self._poll_loop()) except asyncio.CancelledError: - pass - + pass async def _poll_loop(self) -> None: while True: @@ -30,8 +36,7 @@ async def _poll_loop(self) -> None: except asyncio.CancelledError as e: raise e except Exception: - logger.exception('Exception while polling') - + logger.exception("Exception while polling") async def _poll_and_dispatch(self) -> None: await self._permits.acquire() @@ -54,6 +59,6 @@ async def 
_execute_callback(self, task: T) -> None: try: await self._callback(task) except Exception: - logger.exception('Exception during callback') + logger.exception("Exception during callback") finally: - self._permits.release() \ No newline at end of file + self._permits.release() diff --git a/cadence/worker/_registry.py b/cadence/worker/_registry.py index f45c42a..05bff67 100644 --- a/cadence/worker/_registry.py +++ b/cadence/worker/_registry.py @@ -7,29 +7,47 @@ """ import logging -from typing import Callable, Dict, Optional, Unpack, TypedDict, overload, Type, Union, TypeVar -from cadence.activity import ActivityDefinitionOptions, ActivityDefinition, ActivityDecorator, P, T +from typing import ( + Callable, + Dict, + Optional, + Unpack, + TypedDict, + overload, + Type, + Union, + TypeVar, +) +from cadence.activity import ( + ActivityDefinitionOptions, + ActivityDefinition, + ActivityDecorator, + P, + T, +) from cadence.workflow import WorkflowDefinition, WorkflowDefinitionOptions logger = logging.getLogger(__name__) # TypeVar for workflow class types -W = TypeVar('W') +W = TypeVar("W") class RegisterWorkflowOptions(TypedDict, total=False): """Options for registering a workflow.""" + name: Optional[str] alias: Optional[str] + class Registry: """ Registry for managing workflows and activities. - + This class provides functionality to register, retrieve, and manage workflows and activities in a Cadence application. """ - + def __init__(self) -> None: """Initialize the registry.""" self._workflows: Dict[str, WorkflowDefinition] = {} @@ -37,9 +55,7 @@ def __init__(self) -> None: self._workflow_aliases: Dict[str, str] = {} # alias -> name mapping def workflow( - self, - cls: Optional[Type[W]] = None, - **kwargs: Unpack[RegisterWorkflowOptions] + self, cls: Optional[Type[W]] = None, **kwargs: Unpack[RegisterWorkflowOptions] ) -> Union[Type[W], Callable[[Type[W]], Type[W]]]: """ Register a workflow class. @@ -61,7 +77,7 @@ def workflow( options = RegisterWorkflowOptions(**kwargs) def decorator(target: Type[W]) -> Type[W]: - workflow_name = options.get('name') or target.__name__ + workflow_name = options.get("name") or target.__name__ if workflow_name in self._workflows: raise KeyError(f"Workflow '{workflow_name}' is already registered") @@ -72,7 +88,7 @@ def decorator(target: Type[W]) -> Type[W]: self._workflows[workflow_name] = workflow_def # Register alias if provided - alias = options.get('alias') + alias = options.get("alias") if alias: if alias in self._workflow_aliases: raise KeyError(f"Workflow alias '{alias}' is already registered") @@ -86,26 +102,30 @@ def decorator(target: Type[W]) -> Type[W]: return decorator(cls) @overload - def activity(self, func: Callable[P, T]) -> ActivityDefinition[P, T]: - ... + def activity(self, func: Callable[P, T]) -> ActivityDefinition[P, T]: ... @overload - def activity(self, **kwargs: Unpack[ActivityDefinitionOptions]) -> ActivityDecorator: - ... - - def activity(self, func: Callable[P, T] | None = None, **kwargs: Unpack[ActivityDefinitionOptions]) -> ActivityDecorator | ActivityDefinition[P, T]: + def activity( + self, **kwargs: Unpack[ActivityDefinitionOptions] + ) -> ActivityDecorator: ... + + def activity( + self, + func: Callable[P, T] | None = None, + **kwargs: Unpack[ActivityDefinitionOptions], + ) -> ActivityDecorator | ActivityDefinition[P, T]: """ Register an activity function. - + This method can be used as a decorator or called directly. 
- + Args: func: The activity function to register **kwargs: Options for registration (name, alias) - + Returns: The decorated function or the function itself - + Raises: KeyError: If activity name already exists """ @@ -131,7 +151,6 @@ def register_activities(self, obj: object) -> None: for defn in activities: self._register_activity(defn) - def register_activity(self, defn: Callable) -> None: if not isinstance(defn, ActivityDefinition): raise ValueError(f"{defn.__qualname__} must have @activity.defn decorator") @@ -143,7 +162,6 @@ def _register_activity(self, defn: ActivityDefinition) -> None: self._activities[defn.name] = defn - def get_workflow(self, name: str) -> WorkflowDefinition: """ Get a registered workflow by name. @@ -164,23 +182,23 @@ def get_workflow(self, name: str) -> WorkflowDefinition: raise KeyError(f"Workflow '{name}' not found in registry") return self._workflows[actual_name] - + def get_activity(self, name: str) -> ActivityDefinition: """ Get a registered activity by name. - + Args: name: Name or alias of the activity - + Returns: The activity function - + Raises: KeyError: If activity is not found """ return self._activities[name] - def __add__(self, other: 'Registry') -> 'Registry': + def __add__(self, other: "Registry") -> "Registry": result = Registry() for name, fn in self._activities.items(): result._register_activity(fn) @@ -190,13 +208,14 @@ def __add__(self, other: 'Registry') -> 'Registry': return result @staticmethod - def of(*args: 'Registry') -> 'Registry': + def of(*args: "Registry") -> "Registry": result = Registry() for other in args: result += other return result + def _find_activity_definitions(instance: object) -> list[ActivityDefinition]: attr_to_def = {} for t in instance.__class__.__mro__: @@ -206,14 +225,20 @@ def _find_activity_definitions(instance: object) -> list[ActivityDefinition]: value = getattr(t, attr) if isinstance(value, ActivityDefinition): if attr in attr_to_def: - raise ValueError(f"'{attr}' was overridden with a duplicate activity definition") + raise ValueError( + f"'{attr}' was overridden with a duplicate activity definition" + ) attr_to_def[attr] = value result: list[ActivityDefinition] = [] for attr, definition in attr_to_def.items(): - result.append(ActivityDefinition(getattr(instance, attr), definition.name, definition.strategy, definition.params)) + result.append( + ActivityDefinition( + getattr(instance, attr), + definition.name, + definition.strategy, + definition.params, + ) + ) return result - - - \ No newline at end of file diff --git a/cadence/worker/_types.py b/cadence/worker/_types.py index 8b16fed..ecf620f 100644 --- a/cadence/worker/_types.py +++ b/cadence/worker/_types.py @@ -12,6 +12,7 @@ class WorkerOptions(TypedDict, total=False): disable_activity_worker: bool identity: str + _DEFAULT_WORKER_OPTIONS: WorkerOptions = { "max_concurrent_activity_execution_size": 1000, "max_concurrent_decision_task_execution_size": 1000, diff --git a/cadence/worker/_worker.py b/cadence/worker/_worker.py index ff273ad..6d57d30 100644 --- a/cadence/worker/_worker.py +++ b/cadence/worker/_worker.py @@ -10,8 +10,13 @@ class Worker: - - def __init__(self, client: Client, task_list: str, registry: Registry, **kwargs: Unpack[WorkerOptions]) -> None: + def __init__( + self, + client: Client, + task_list: str, + registry: Registry, + **kwargs: Unpack[WorkerOptions], + ) -> None: self._client = client self._task_list = task_list @@ -21,7 +26,6 @@ def __init__(self, client: Client, task_list: str, registry: Registry, **kwargs: 
self._activity_worker = ActivityWorker(client, task_list, registry, options) self._decision_worker = DecisionWorker(client, task_list, registry, options) - async def run(self) -> None: async with asyncio.TaskGroup() as tg: if not self._options["disable_workflow_worker"]: @@ -30,8 +34,9 @@ async def run(self) -> None: tg.create_task(self._activity_worker.run()) - -def _validate_and_copy_defaults(client: Client, task_list: str, options: WorkerOptions) -> None: +def _validate_and_copy_defaults( + client: Client, task_list: str, options: WorkerOptions +) -> None: if "identity" not in options: options["identity"] = f"{client.identity}@{task_list}@{uuid.uuid4()}" diff --git a/cadence/workflow.py b/cadence/workflow.py index 14cabec..0e346ea 100644 --- a/cadence/workflow.py +++ b/cadence/workflow.py @@ -2,16 +2,27 @@ from contextlib import contextmanager from contextvars import ContextVar from dataclasses import dataclass -from typing import Iterator, Callable, TypeVar, TypedDict, Type, cast, Any, Optional, Union +from typing import ( + Iterator, + Callable, + TypeVar, + TypedDict, + Type, + cast, + Any, + Optional, + Union, +) import inspect from cadence.client import Client -T = TypeVar('T', bound=Callable[..., Any]) +T = TypeVar("T", bound=Callable[..., Any]) class WorkflowDefinitionOptions(TypedDict, total=False): """Options for defining a workflow.""" + name: str @@ -43,7 +54,7 @@ def get_run_method(self, instance: Any) -> Callable: return cast(Callable, getattr(instance, self._run_method_name)) @staticmethod - def wrap(cls: Type, opts: WorkflowDefinitionOptions) -> 'WorkflowDefinition': + def wrap(cls: Type, opts: WorkflowDefinitionOptions) -> "WorkflowDefinition": """ Wrap a class as a WorkflowDefinition. @@ -64,7 +75,7 @@ def wrap(cls: Type, opts: WorkflowDefinitionOptions) -> 'WorkflowDefinition': # Validate that the class has exactly one run method and find it run_method_name = None for attr_name in dir(cls): - if attr_name.startswith('_'): + if attr_name.startswith("_"): continue attr = getattr(cls, attr_name) @@ -72,15 +83,17 @@ def wrap(cls: Type, opts: WorkflowDefinitionOptions) -> 'WorkflowDefinition': continue # Check for workflow run method - if hasattr(attr, '_workflow_run'): + if hasattr(attr, "_workflow_run"): if run_method_name is not None: - raise ValueError(f"Multiple @workflow.run methods found in class {cls.__name__}") + raise ValueError( + f"Multiple @workflow.run methods found in class {cls.__name__}" + ) run_method_name = attr_name if run_method_name is None: raise ValueError(f"No @workflow.run method found in class {cls.__name__}") - return WorkflowDefinition(cls, name, run_method_name) + return WorkflowDefinition(cls, name, run_method_name) def run(func: Optional[T] = None) -> Union[T, Callable[[T], T]]: @@ -101,15 +114,16 @@ async def my_workflow(self): Returns: The decorated method with workflow run metadata - + Raises: ValueError: If the function is not async """ + def decorator(f: T) -> T: # Validate that the function is async if not inspect.iscoroutinefunction(f): raise ValueError(f"Workflow run method '{f.__name__}' must be async") - + # Attach metadata to the function f._workflow_run = True # type: ignore return f @@ -130,16 +144,15 @@ class WorkflowInfo: workflow_id: str workflow_run_id: str + class WorkflowContext(ABC): - _var: ContextVar['WorkflowContext'] = ContextVar("workflow") + _var: ContextVar["WorkflowContext"] = ContextVar("workflow") @abstractmethod - def info(self) -> WorkflowInfo: - ... + def info(self) -> WorkflowInfo: ... 
@abstractmethod - def client(self) -> Client: - ... + def client(self) -> Client: ... @contextmanager def _activate(self) -> Iterator[None]: @@ -152,5 +165,5 @@ def is_set() -> bool: return WorkflowContext._var.get(None) is not None @staticmethod - def get() -> 'WorkflowContext': + def get() -> "WorkflowContext": return WorkflowContext._var.get() diff --git a/scripts/dev.py b/scripts/dev.py index 143d9fc..e626ca6 100755 --- a/scripts/dev.py +++ b/scripts/dev.py @@ -29,7 +29,9 @@ def install(): def install_dev(): """Install the package with development dependencies.""" - return run_command("uv pip install -e '.[dev]'", "Installing package with dev dependencies") + return run_command( + "uv pip install -e '.[dev]'", "Installing package with dev dependencies" + ) def test(): @@ -39,7 +41,10 @@ def test(): def test_cov(): """Run tests with coverage.""" - return run_command("uv run pytest --cov=cadence --cov-report=html --cov-report=term-missing", "Running tests with coverage") + return run_command( + "uv run pytest --cov=cadence --cov-report=html --cov-report=term-missing", + "Running tests with coverage", + ) def lint(): @@ -93,7 +98,9 @@ def clean(): run_command(f"rm -rf {dir_pattern}", f"Removing {dir_pattern}") # Remove Python cache files - run_command("find . -type d -name __pycache__ -delete", "Removing __pycache__ directories") + run_command( + "find . -type d -name __pycache__ -delete", "Removing __pycache__ directories" + ) run_command("find . -type f -name '*.pyc' -delete", "Removing .pyc files") print("✓ Clean completed") @@ -112,7 +119,10 @@ def protobuf(): def docs(): """Build documentation.""" - return run_command("uv run sphinx-build -b html docs/source docs/build/html", "Building documentation") + return run_command( + "uv run sphinx-build -b html docs/source docs/build/html", + "Building documentation", + ) def check(): @@ -128,11 +138,26 @@ def check(): def main(): """Main function.""" - parser = argparse.ArgumentParser(description="Development script for Cadence Python client") - parser.add_argument("command", choices=[ - "install", "install-dev", "test", "test-cov", "lint", "format", - "clean", "build", "protobuf", "docs", "check" - ], help="Command to run") + parser = argparse.ArgumentParser( + description="Development script for Cadence Python client" + ) + parser.add_argument( + "command", + choices=[ + "install", + "install-dev", + "test", + "test-cov", + "lint", + "format", + "clean", + "build", + "protobuf", + "docs", + "check", + ], + help="Command to run", + ) args = parser.parse_args() diff --git a/scripts/generate_proto.py b/scripts/generate_proto.py index e50efd5..8cc7b5f 100644 --- a/scripts/generate_proto.py +++ b/scripts/generate_proto.py @@ -10,6 +10,7 @@ import shutil import runpy + def get_project_root() -> Path: try: return Path( @@ -20,6 +21,7 @@ def get_project_root() -> Path: except Exception as e: raise RuntimeError("Error: Could not determine project root from git:", e) + def generate_init_file(output_dir: Path) -> None: """Generate the __init__.py file for cadence/api/v1 with clean imports.""" v1_dir = output_dir / "cadence" / "api" / "v1" @@ -60,11 +62,12 @@ def generate_init_file(output_dir: Path) -> None: content += "]\n" # Write the file - with open(init_file, 'w') as f: + with open(init_file, "w") as f: f.write(content) print(f" ✓ Generated {init_file} with {len(pb2_files)} modules") + def setup_temp_proto_structure(proto_dir: Path, temp_dir: Path) -> None: """Create a temporary directory with proto files in the proper structure for 
cadence.api.v1 imports.""" print("Setting up temporary proto structure...") @@ -78,24 +81,28 @@ def setup_temp_proto_structure(proto_dir: Path, temp_dir: Path) -> None: # Copy all proto files from proto_dir to temp_dir for proto_file in proto_file_dir.glob("*.proto"): # Copy the proto file and update import statements - with open(proto_file, 'r') as src_file: + with open(proto_file, "r") as src_file: content = src_file.read() # Update import statements to remove 'uber/' prefix # Replace "uber/cadence/api/v1/" with "cadence/api/v1/" - updated_content = content.replace('import "uber/cadence/api/v1/', 'import "cadence/api/v1/') + updated_content = content.replace( + 'import "uber/cadence/api/v1/', 'import "cadence/api/v1/' + ) # Write the updated content to the target file - with open(output_dir / proto_file.name, 'w') as dst_file: + with open(output_dir / proto_file.name, "w") as dst_file: dst_file.write(updated_content) print(f" ✓ Copied and updated {proto_file.name}") + def delete_temp_dir(temp_dir: Path): if temp_dir.exists(): shutil.rmtree(temp_dir) print(f"Deleted temp directory: {temp_dir}") + def generate_protobuf_files(temp_dir: Path, gen_dir: Path) -> None: # Find all .proto files in the cadence/api/v1 directory proto_files = list((temp_dir / "cadence/api/v1").glob("*.proto")) @@ -112,18 +119,26 @@ def generate_protobuf_files(temp_dir: Path, gen_dir: Path) -> None: original_argv = sys.argv sys.argv = [ "grpc_tools.protoc", - "--proto_path", str(temp_dir), - "--python_out", str(gen_dir), - "--pyi_out", str(gen_dir), - "--grpc_python_out", str(gen_dir) + "--proto_path", + str(temp_dir), + "--python_out", + str(gen_dir), + "--pyi_out", + str(gen_dir), + "--grpc_python_out", + str(gen_dir), ] + proto_file_paths try: runpy.run_module("grpc_tools.protoc", run_name="__main__", alter_sys=True) - print(f"Successfully generated protobuf files using runpy for {len(proto_files)} files") + print( + f"Successfully generated protobuf files using runpy for {len(proto_files)} files" + ) except SystemExit as e: if e.code == 0: - print(f"Successfully generated protobuf files using runpy for {len(proto_files)} files") + print( + f"Successfully generated protobuf files using runpy for {len(proto_files)} files" + ) else: print("Error running grpc_tools.protoc via runpy {}", e) raise e @@ -131,6 +146,7 @@ def generate_protobuf_files(temp_dir: Path, gen_dir: Path) -> None: # Restore original argv sys.argv = original_argv + def main(): project_root = get_project_root() @@ -144,5 +160,6 @@ def main(): generate_init_file(gen_dir) delete_temp_dir(temp_dir) + if __name__ == "__main__": main() diff --git a/tests/cadence/_internal/activity/test_activity_executor.py b/tests/cadence/_internal/activity/test_activity_executor.py index abc02cb..5bee843 100644 --- a/tests/cadence/_internal/activity/test_activity_executor.py +++ b/tests/cadence/_internal/activity/test_activity_executor.py @@ -9,9 +9,20 @@ from cadence import activity, Client from cadence._internal.activity import ActivityExecutor from cadence.activity import ActivityInfo, ActivityDefinition -from cadence.api.v1.common_pb2 import WorkflowExecution, ActivityType, Payload, Failure, WorkflowType -from cadence.api.v1.service_worker_pb2 import RespondActivityTaskCompletedResponse, PollForActivityTaskResponse, \ - RespondActivityTaskCompletedRequest, RespondActivityTaskFailedResponse, RespondActivityTaskFailedRequest +from cadence.api.v1.common_pb2 import ( + WorkflowExecution, + ActivityType, + Payload, + Failure, + WorkflowType, +) +from 
cadence.api.v1.service_worker_pb2 import ( + RespondActivityTaskCompletedResponse, + PollForActivityTaskResponse, + RespondActivityTaskCompletedRequest, + RespondActivityTaskFailedResponse, + RespondActivityTaskFailedRequest, +) from cadence.data_converter import DefaultDataConverter from cadence.worker import Registry @@ -26,33 +37,42 @@ def client() -> Client: async def test_activity_async_success(client): worker_stub = client.worker_stub - worker_stub.RespondActivityTaskCompleted = AsyncMock(return_value=RespondActivityTaskCompletedResponse()) + worker_stub.RespondActivityTaskCompleted = AsyncMock( + return_value=RespondActivityTaskCompletedResponse() + ) reg = Registry() + @reg.activity(name="activity_type") async def activity_fn(): return "success" - executor = ActivityExecutor(client, 'task_list', 'identity', 1, reg.get_activity) + executor = ActivityExecutor(client, "task_list", "identity", 1, reg.get_activity) await executor.execute(fake_task("activity_type", "")) - worker_stub.RespondActivityTaskCompleted.assert_called_once_with(RespondActivityTaskCompletedRequest( - task_token=b'task_token', - result=Payload(data='"success"'.encode()), - identity='identity', - )) + worker_stub.RespondActivityTaskCompleted.assert_called_once_with( + RespondActivityTaskCompletedRequest( + task_token=b"task_token", + result=Payload(data='"success"'.encode()), + identity="identity", + ) + ) + async def test_activity_async_failure(client): worker_stub = client.worker_stub - worker_stub.RespondActivityTaskFailed = AsyncMock(return_value=RespondActivityTaskFailedResponse()) + worker_stub.RespondActivityTaskFailed = AsyncMock( + return_value=RespondActivityTaskFailedResponse() + ) reg = Registry() + @reg.activity(name="activity_type") async def activity_fn(): raise KeyError("failure") - executor = ActivityExecutor(client, 'task_list', 'identity', 1, reg.get_activity) + executor = ActivityExecutor(client, "task_list", "identity", 1, reg.get_activity) await executor.execute(fake_task("activity_type", "")) @@ -64,37 +84,47 @@ async def activity_fn(): assert 'raise KeyError("failure")' in call.failure.details.decode() call.failure.details = bytes() assert call == RespondActivityTaskFailedRequest( - task_token=b'task_token', + task_token=b"task_token", failure=Failure( reason="KeyError", ), - identity='identity', + identity="identity", ) + async def test_activity_args(client): worker_stub = client.worker_stub - worker_stub.RespondActivityTaskCompleted = AsyncMock(return_value=RespondActivityTaskCompletedResponse()) + worker_stub.RespondActivityTaskCompleted = AsyncMock( + return_value=RespondActivityTaskCompletedResponse() + ) reg = Registry() + @reg.activity(name="activity_type") async def activity_fn(first: str, second: str): return " ".join([first, second]) - executor = ActivityExecutor(client, 'task_list', 'identity', 1, reg.get_activity) + executor = ActivityExecutor(client, "task_list", "identity", 1, reg.get_activity) await executor.execute(fake_task("activity_type", '"hello" "world"')) - worker_stub.RespondActivityTaskCompleted.assert_called_once_with(RespondActivityTaskCompletedRequest( - task_token=b'task_token', - result=Payload(data='"hello world"'.encode()), - identity='identity', - )) + worker_stub.RespondActivityTaskCompleted.assert_called_once_with( + RespondActivityTaskCompletedRequest( + task_token=b"task_token", + result=Payload(data='"hello world"'.encode()), + identity="identity", + ) + ) + async def test_activity_sync_success(client): worker_stub = client.worker_stub - 
worker_stub.RespondActivityTaskCompleted = AsyncMock(return_value=RespondActivityTaskCompletedResponse()) + worker_stub.RespondActivityTaskCompleted = AsyncMock( + return_value=RespondActivityTaskCompletedResponse() + ) reg = Registry() + @reg.activity(name="activity_type") def activity_fn(): try: @@ -103,25 +133,31 @@ def activity_fn(): return "success" raise RuntimeError("expected to be running outside of the event loop") - executor = ActivityExecutor(client, 'task_list', 'identity', 1, reg.get_activity) + executor = ActivityExecutor(client, "task_list", "identity", 1, reg.get_activity) await executor.execute(fake_task("activity_type", "")) - worker_stub.RespondActivityTaskCompleted.assert_called_once_with(RespondActivityTaskCompletedRequest( - task_token=b'task_token', - result=Payload(data='"success"'.encode()), - identity='identity', - )) + worker_stub.RespondActivityTaskCompleted.assert_called_once_with( + RespondActivityTaskCompletedRequest( + task_token=b"task_token", + result=Payload(data='"success"'.encode()), + identity="identity", + ) + ) + async def test_activity_sync_failure(client): worker_stub = client.worker_stub - worker_stub.RespondActivityTaskFailed = AsyncMock(return_value=RespondActivityTaskFailedResponse()) + worker_stub.RespondActivityTaskFailed = AsyncMock( + return_value=RespondActivityTaskFailedResponse() + ) reg = Registry() + @reg.activity(name="activity_type") def activity_fn(): raise KeyError("failure") - executor = ActivityExecutor(client, 'task_list', 'identity', 1, reg.get_activity) + executor = ActivityExecutor(client, "task_list", "identity", 1, reg.get_activity) await executor.execute(fake_task("activity_type", "")) @@ -133,21 +169,24 @@ def activity_fn(): assert 'raise KeyError("failure")' in call.failure.details.decode() call.failure.details = bytes() assert call == RespondActivityTaskFailedRequest( - task_token=b'task_token', + task_token=b"task_token", failure=Failure( reason="KeyError", ), - identity='identity', + identity="identity", ) + async def test_activity_unknown(client): worker_stub = client.worker_stub - worker_stub.RespondActivityTaskFailed = AsyncMock(return_value=RespondActivityTaskFailedResponse()) + worker_stub.RespondActivityTaskFailed = AsyncMock( + return_value=RespondActivityTaskFailedResponse() + ) def registry(name: str) -> ActivityDefinition: raise KeyError(f"unknown activity: {name}") - executor = ActivityExecutor(client, 'task_list', 'identity', 1, registry) + executor = ActivityExecutor(client, "task_list", "identity", 1, registry) await executor.execute(fake_task("activity_type", "")) @@ -155,20 +194,24 @@ def registry(name: str) -> ActivityDefinition: call = worker_stub.RespondActivityTaskFailed.call_args[0][0] - assert 'Activity type not found: activity_type' in call.failure.details.decode() + assert "Activity type not found: activity_type" in call.failure.details.decode() call.failure.details = bytes() assert call == RespondActivityTaskFailedRequest( - task_token=b'task_token', + task_token=b"task_token", failure=Failure( reason="KeyError", ), - identity='identity', + identity="identity", ) + async def test_activity_context(client): worker_stub = client.worker_stub - worker_stub.RespondActivityTaskCompleted = AsyncMock(return_value=RespondActivityTaskCompletedResponse()) + worker_stub.RespondActivityTaskCompleted = AsyncMock( + return_value=RespondActivityTaskCompletedResponse() + ) reg = Registry() + @reg.activity(name="activity_type") async def activity_fn(): assert fake_info("activity_type") == activity.info() @@ 
-176,21 +219,27 @@ async def activity_fn(): assert activity.client() is not None return "success" - executor = ActivityExecutor(client, 'task_list', 'identity', 1, reg.get_activity) + executor = ActivityExecutor(client, "task_list", "identity", 1, reg.get_activity) await executor.execute(fake_task("activity_type", "")) - worker_stub.RespondActivityTaskCompleted.assert_called_once_with(RespondActivityTaskCompletedRequest( - task_token=b'task_token', - result=Payload(data='"success"'.encode()), - identity='identity', - )) + worker_stub.RespondActivityTaskCompleted.assert_called_once_with( + RespondActivityTaskCompletedRequest( + task_token=b"task_token", + result=Payload(data='"success"'.encode()), + identity="identity", + ) + ) + async def test_activity_context_sync(client): worker_stub = client.worker_stub - worker_stub.RespondActivityTaskCompleted = AsyncMock(return_value=RespondActivityTaskCompletedResponse()) + worker_stub.RespondActivityTaskCompleted = AsyncMock( + return_value=RespondActivityTaskCompletedResponse() + ) reg = Registry() + @reg.activity(name="activity_type") def activity_fn(): assert fake_info("activity_type") == activity.info() @@ -199,20 +248,22 @@ def activity_fn(): activity.client() return "success" - executor = ActivityExecutor(client, 'task_list', 'identity', 1, reg.get_activity) + executor = ActivityExecutor(client, "task_list", "identity", 1, reg.get_activity) await executor.execute(fake_task("activity_type", "")) - worker_stub.RespondActivityTaskCompleted.assert_called_once_with(RespondActivityTaskCompletedRequest( - task_token=b'task_token', - result=Payload(data='"success"'.encode()), - identity='identity', - )) + worker_stub.RespondActivityTaskCompleted.assert_called_once_with( + RespondActivityTaskCompletedRequest( + task_token=b"task_token", + result=Payload(data='"success"'.encode()), + identity="identity", + ) + ) def fake_info(activity_type: str) -> ActivityInfo: return ActivityInfo( - task_token=b'task_token', + task_token=b"task_token", workflow_domain="workflow_domain", workflow_id="workflow_id", workflow_run_id="run_id", @@ -222,14 +273,15 @@ def fake_info(activity_type: str) -> ActivityInfo: workflow_type="workflow_type", task_list="task_list", heartbeat_timeout=timedelta(seconds=1), - scheduled_timestamp=datetime(2020, 1, 2 ,3), - started_timestamp=datetime(2020, 1, 2 ,4), + scheduled_timestamp=datetime(2020, 1, 2, 3), + started_timestamp=datetime(2020, 1, 2, 4), start_to_close_timeout=timedelta(seconds=2), ) + def fake_task(activity_type: str, input_json: str) -> PollForActivityTaskResponse: return PollForActivityTaskResponse( - task_token=b'task_token', + task_token=b"task_token", workflow_domain="workflow_domain", workflow_type=WorkflowType(name="workflow_type"), workflow_execution=WorkflowExecution( @@ -246,6 +298,7 @@ def fake_task(activity_type: str, input_json: str) -> PollForActivityTaskRespons start_to_close_timeout=from_timedelta(timedelta(seconds=2)), ) + def from_datetime(time: datetime) -> Timestamp: t = Timestamp() t.FromDatetime(time) diff --git a/tests/cadence/_internal/rpc/test_error.py b/tests/cadence/_internal/rpc/test_error.py index 8ca0c3e..2b67a7c 100644 --- a/tests/cadence/_internal/rpc/test_error.py +++ b/tests/cadence/_internal/rpc/test_error.py @@ -39,64 +39,138 @@ def fake_service(): yield fake sync_server.stop(grace=None) + @pytest.mark.usefixtures("fake_service") @pytest.mark.parametrize( "err,expected", [ - pytest.param(None, None,id="no error"), + pytest.param(None, None, id="no error"), pytest.param( - 
error_pb2.WorkflowExecutionAlreadyStartedError(start_request_id="start_request", run_id="run_id"), - error.WorkflowExecutionAlreadyStartedError(message="message", code=StatusCode.INVALID_ARGUMENT, start_request_id="start_request", run_id="run_id"), - id="WorkflowExecutionAlreadyStartedError"), + error_pb2.WorkflowExecutionAlreadyStartedError( + start_request_id="start_request", run_id="run_id" + ), + error.WorkflowExecutionAlreadyStartedError( + message="message", + code=StatusCode.INVALID_ARGUMENT, + start_request_id="start_request", + run_id="run_id", + ), + id="WorkflowExecutionAlreadyStartedError", + ), pytest.param( - error_pb2.EntityNotExistsError(current_cluster="current_cluster", active_cluster="active_cluster", active_clusters=["active_clusters"]), - error.EntityNotExistsError(message="message", code=StatusCode.INVALID_ARGUMENT, current_cluster="current_cluster", active_cluster="active_cluster", active_clusters=["active_clusters"]), - id="EntityNotExistsError"), + error_pb2.EntityNotExistsError( + current_cluster="current_cluster", + active_cluster="active_cluster", + active_clusters=["active_clusters"], + ), + error.EntityNotExistsError( + message="message", + code=StatusCode.INVALID_ARGUMENT, + current_cluster="current_cluster", + active_cluster="active_cluster", + active_clusters=["active_clusters"], + ), + id="EntityNotExistsError", + ), pytest.param( error_pb2.WorkflowExecutionAlreadyCompletedError(), - error.WorkflowExecutionAlreadyCompletedError(message="message", code=StatusCode.INVALID_ARGUMENT), - id="WorkflowExecutionAlreadyCompletedError"), + error.WorkflowExecutionAlreadyCompletedError( + message="message", code=StatusCode.INVALID_ARGUMENT + ), + id="WorkflowExecutionAlreadyCompletedError", + ), pytest.param( - error_pb2.DomainNotActiveError(domain="domain", current_cluster="current_cluster", active_cluster="active_cluster", active_clusters=["active_clusters"]), - error.DomainNotActiveError(message="message", code=StatusCode.INVALID_ARGUMENT, domain="domain", current_cluster="current_cluster", active_cluster="active_cluster", active_clusters=["active_clusters"]), - id="DomainNotActiveError"), + error_pb2.DomainNotActiveError( + domain="domain", + current_cluster="current_cluster", + active_cluster="active_cluster", + active_clusters=["active_clusters"], + ), + error.DomainNotActiveError( + message="message", + code=StatusCode.INVALID_ARGUMENT, + domain="domain", + current_cluster="current_cluster", + active_cluster="active_cluster", + active_clusters=["active_clusters"], + ), + id="DomainNotActiveError", + ), pytest.param( - error_pb2.ClientVersionNotSupportedError(feature_version="feature_version", client_impl="client_impl", supported_versions="supported_versions"), - error.ClientVersionNotSupportedError(message="message", code=StatusCode.INVALID_ARGUMENT, feature_version="feature_version", client_impl="client_impl", supported_versions="supported_versions"), - id="ClientVersionNotSupportedError"), + error_pb2.ClientVersionNotSupportedError( + feature_version="feature_version", + client_impl="client_impl", + supported_versions="supported_versions", + ), + error.ClientVersionNotSupportedError( + message="message", + code=StatusCode.INVALID_ARGUMENT, + feature_version="feature_version", + client_impl="client_impl", + supported_versions="supported_versions", + ), + id="ClientVersionNotSupportedError", + ), pytest.param( error_pb2.FeatureNotEnabledError(feature_flag="feature_flag"), - error.FeatureNotEnabledError(message="message", 
code=StatusCode.INVALID_ARGUMENT,feature_flag="feature_flag"), - id="FeatureNotEnabledError"), + error.FeatureNotEnabledError( + message="message", + code=StatusCode.INVALID_ARGUMENT, + feature_flag="feature_flag", + ), + id="FeatureNotEnabledError", + ), pytest.param( error_pb2.CancellationAlreadyRequestedError(), - error.CancellationAlreadyRequestedError(message="message", code=StatusCode.INVALID_ARGUMENT), - id="CancellationAlreadyRequestedError"), + error.CancellationAlreadyRequestedError( + message="message", code=StatusCode.INVALID_ARGUMENT + ), + id="CancellationAlreadyRequestedError", + ), pytest.param( error_pb2.DomainAlreadyExistsError(), - error.DomainAlreadyExistsError(message="message", code=StatusCode.INVALID_ARGUMENT), - id="DomainAlreadyExistsError"), + error.DomainAlreadyExistsError( + message="message", code=StatusCode.INVALID_ARGUMENT + ), + id="DomainAlreadyExistsError", + ), pytest.param( error_pb2.LimitExceededError(), - error.LimitExceededError(message="message", code=StatusCode.INVALID_ARGUMENT), - id="LimitExceededError"), + error.LimitExceededError( + message="message", code=StatusCode.INVALID_ARGUMENT + ), + id="LimitExceededError", + ), pytest.param( error_pb2.QueryFailedError(), error.QueryFailedError(message="message", code=StatusCode.INVALID_ARGUMENT), - id="QueryFailedError"), + id="QueryFailedError", + ), pytest.param( error_pb2.ServiceBusyError(reason="reason"), - error.ServiceBusyError(message="message", code=StatusCode.INVALID_ARGUMENT, reason="reason"), - id="ServiceBusyError"), + error.ServiceBusyError( + message="message", code=StatusCode.INVALID_ARGUMENT, reason="reason" + ), + id="ServiceBusyError", + ), pytest.param( - to_status(status_pb2.Status(code=code_pb2.PERMISSION_DENIED, message="no permission")), - error.CadenceError(message="no permission", code=StatusCode.PERMISSION_DENIED), - id="unknown error type"), - ] + to_status( + status_pb2.Status( + code=code_pb2.PERMISSION_DENIED, message="no permission" + ) + ), + error.CadenceError( + message="no permission", code=StatusCode.PERMISSION_DENIED + ), + id="unknown error type", + ), + ], ) @pytest.mark.asyncio async def test_map_error(fake_service, err: Message | Status, expected: CadenceError): - async with insecure_channel(f"[::]:{fake_service.port}", interceptors=[CadenceErrorInterceptor()]) as channel: + async with insecure_channel( + f"[::]:{fake_service.port}", interceptors=[CadenceErrorInterceptor()] + ) as channel: stub = service_meta_pb2_grpc.MetaAPIStub(channel) if expected is None: response = await stub.Health(HealthRequest(), timeout=1) @@ -110,6 +184,7 @@ async def test_map_error(fake_service, err: Message | Status, expected: CadenceE await stub.Health(HealthRequest(), timeout=1) assert exc_info.value.args == expected.args + def details_to_status(message: Message) -> Status: detail = any_pb2.Any() detail.Pack(message) @@ -119,4 +194,3 @@ def details_to_status(message: Message) -> Status: details=[detail], ) return to_status(status_proto) - diff --git a/tests/cadence/_internal/rpc/test_retry.py b/tests/cadence/_internal/rpc/test_retry.py index 1874341..aed7f57 100644 --- a/tests/cadence/_internal/rpc/test_retry.py +++ b/tests/cadence/_internal/rpc/test_retry.py @@ -12,42 +12,48 @@ from cadence.api.v1 import error_pb2, service_workflow_pb2_grpc from cadence._internal.rpc.retry import ExponentialRetryPolicy, RetryInterceptor -from cadence.api.v1.service_workflow_pb2 import DescribeWorkflowExecutionResponse, \ - DescribeWorkflowExecutionRequest, GetWorkflowExecutionHistoryRequest +from 
cadence.api.v1.service_workflow_pb2 import ( + DescribeWorkflowExecutionResponse, + DescribeWorkflowExecutionRequest, + GetWorkflowExecutionHistoryRequest, +) from cadence.error import CadenceError, FeatureNotEnabledError, EntityNotExistsError -simple_policy = ExponentialRetryPolicy(initial_interval=1, backoff_coefficient=2, max_interval=10, max_attempts=6) +simple_policy = ExponentialRetryPolicy( + initial_interval=1, backoff_coefficient=2, max_interval=10, max_attempts=6 +) + @pytest.mark.parametrize( "policy,params,expected", [ - pytest.param( - simple_policy, (1, 0.0, 100.0), 1, id="happy path" - ), - pytest.param( - simple_policy, (2, 0.0, 100.0), 2, id="second attempt" - ), - pytest.param( - simple_policy, (3, 0.0, 100.0), 4, id="third attempt" - ), - pytest.param( - simple_policy, (5, 0.0, 100.0), 10, id="capped by max_interval" - ), - pytest.param( - simple_policy, (6, 0.0, 100.0), None, id="out of attempts" - ), - pytest.param( - simple_policy, (1, 100.0, 100.0), None, id="timeout" - ), + pytest.param(simple_policy, (1, 0.0, 100.0), 1, id="happy path"), + pytest.param(simple_policy, (2, 0.0, 100.0), 2, id="second attempt"), + pytest.param(simple_policy, (3, 0.0, 100.0), 4, id="third attempt"), + pytest.param(simple_policy, (5, 0.0, 100.0), 10, id="capped by max_interval"), + pytest.param(simple_policy, (6, 0.0, 100.0), None, id="out of attempts"), + pytest.param(simple_policy, (1, 100.0, 100.0), None, id="timeout"), pytest.param( simple_policy, (1, 99.0, 100.0), None, id="backoff causes timeout" ), pytest.param( - ExponentialRetryPolicy(initial_interval=1, backoff_coefficient=1, max_interval=10, max_attempts=0), (100, 0.0, 100.0), 1, id="unlimited retries" + ExponentialRetryPolicy( + initial_interval=1, + backoff_coefficient=1, + max_interval=10, + max_attempts=0, + ), + (100, 0.0, 100.0), + 1, + id="unlimited retries", ), - ] + ], ) -def test_next_delay(policy: ExponentialRetryPolicy, params: Tuple[int, float, float], expected: float | None): +def test_next_delay( + policy: ExponentialRetryPolicy, + params: Tuple[int, float, float], + expected: float | None, +): assert policy.next_delay(*params) == expected @@ -58,22 +64,29 @@ def __init__(self) -> None: self.counter = 0 # Retryable only because it's GetWorkflowExecutionHistory - def GetWorkflowExecutionHistory(self, request: GetWorkflowExecutionHistoryRequest, context): + def GetWorkflowExecutionHistory( + self, request: GetWorkflowExecutionHistoryRequest, context + ): self.counter += 1 detail = any_pb2.Any() - detail.Pack(error_pb2.EntityNotExistsError(current_cluster=request.domain, active_cluster="active")) + detail.Pack( + error_pb2.EntityNotExistsError( + current_cluster=request.domain, active_cluster="active" + ) + ) status_proto = status_pb2.Status( code=code_pb2.NOT_FOUND, message="message", details=[detail], ) context.abort_with_status(to_status(status_proto)) - # Unreachable - + # Unreachable # Not retryable - def DescribeWorkflowExecution(self, request: DescribeWorkflowExecutionRequest, context): + def DescribeWorkflowExecution( + self, request: DescribeWorkflowExecutionRequest, context + ): self.counter += 1 if request.domain == "success": @@ -109,59 +122,69 @@ def fake_service(): yield fake sync_server.stop(grace=None) -TEST_POLICY = ExponentialRetryPolicy(initial_interval=0, backoff_coefficient=0, max_interval=10, max_attempts=10) + +TEST_POLICY = ExponentialRetryPolicy( + initial_interval=0, backoff_coefficient=0, max_interval=10, max_attempts=10 +) + @pytest.mark.usefixtures("fake_service") 
@pytest.mark.parametrize( "case,expected_calls,expected_err", [ + pytest.param("success", 1, None, id="happy path"), + pytest.param("maybe later", 3, None, id="retries then success"), + pytest.param("not retryable", 1, FeatureNotEnabledError, id="not retryable"), pytest.param( - "success", 1, None, id="happy path" - ), - pytest.param( - "maybe later", 3, None, id="retries then success" - ), - pytest.param( - "not retryable", 1, FeatureNotEnabledError, id="not retryable" - ), - pytest.param( - "retryable", TEST_POLICY.max_attempts, FeatureNotEnabledError, id="retries exhausted" + "retryable", + TEST_POLICY.max_attempts, + FeatureNotEnabledError, + id="retries exhausted", ), - - ] + ], ) @pytest.mark.asyncio -async def test_retryable_error(fake_service, case: str, expected_calls: int, expected_err: Type[CadenceError]): +async def test_retryable_error( + fake_service, case: str, expected_calls: int, expected_err: Type[CadenceError] +): fake_service.counter = 0 - async with insecure_channel(f"[::]:{fake_service.port}", interceptors=[RetryInterceptor(TEST_POLICY), CadenceErrorInterceptor()]) as channel: + async with insecure_channel( + f"[::]:{fake_service.port}", + interceptors=[RetryInterceptor(TEST_POLICY), CadenceErrorInterceptor()], + ) as channel: stub = service_workflow_pb2_grpc.WorkflowAPIStub(channel) if expected_err: with pytest.raises(expected_err): - await stub.DescribeWorkflowExecution(DescribeWorkflowExecutionRequest(domain=case), timeout=10) + await stub.DescribeWorkflowExecution( + DescribeWorkflowExecutionRequest(domain=case), timeout=10 + ) else: - await stub.DescribeWorkflowExecution(DescribeWorkflowExecutionRequest(domain=case), timeout=10) + await stub.DescribeWorkflowExecution( + DescribeWorkflowExecutionRequest(domain=case), timeout=10 + ) assert fake_service.counter == expected_calls + @pytest.mark.usefixtures("fake_service") @pytest.mark.parametrize( "case,expected_calls", [ - pytest.param( - "active", 1, id="not retryable" - ), - pytest.param( - "not active", TEST_POLICY.max_attempts, id="retries exhausted" - ), - - ] + pytest.param("active", 1, id="not retryable"), + pytest.param("not active", TEST_POLICY.max_attempts, id="retries exhausted"), + ], ) @pytest.mark.asyncio async def test_workflow_history(fake_service, case: str, expected_calls: int): fake_service.counter = 0 - async with insecure_channel(f"[::]:{fake_service.port}", interceptors=[RetryInterceptor(TEST_POLICY), CadenceErrorInterceptor()]) as channel: + async with insecure_channel( + f"[::]:{fake_service.port}", + interceptors=[RetryInterceptor(TEST_POLICY), CadenceErrorInterceptor()], + ) as channel: stub = service_workflow_pb2_grpc.WorkflowAPIStub(channel) with pytest.raises(EntityNotExistsError): - await stub.GetWorkflowExecutionHistory(GetWorkflowExecutionHistoryRequest(domain=case), timeout=10) + await stub.GetWorkflowExecutionHistory( + GetWorkflowExecutionHistoryRequest(domain=case), timeout=10 + ) - assert fake_service.counter == expected_calls \ No newline at end of file + assert fake_service.counter == expected_calls diff --git a/tests/cadence/_internal/test_decision_state_machine.py b/tests/cadence/_internal/test_decision_state_machine.py index 4f61dca..e967c2f 100644 --- a/tests/cadence/_internal/test_decision_state_machine.py +++ b/tests/cadence/_internal/test_decision_state_machine.py @@ -38,7 +38,7 @@ def test_timer_state_machine_cancel_after_initiated(): timer_id="t-cai" ), ), - "initiated" + "initiated", ) m.request_cancel() d = m.collect_pending_decisions() @@ -57,7 +57,7 @@ def 
test_timer_state_machine_completed_after_cancel(): timer_id="t-cac" ), ), - "initiated" + "initiated", ) m.request_cancel() _ = m.collect_pending_decisions() @@ -68,7 +68,7 @@ def test_timer_state_machine_completed_after_cancel(): timer_id="t-cac", started_event_id=2 ), ), - "completion" + "completion", ) assert m.status is DecisionState.COMPLETED @@ -85,7 +85,7 @@ def test_timer_state_machine_complete_without_cancel(): timer_id="t-cwc" ), ), - "initiated" + "initiated", ) m.handle_event( history.HistoryEvent( @@ -94,14 +94,11 @@ def test_timer_state_machine_complete_without_cancel(): timer_id="t-cwc", started_event_id=4 ), ), - "completion" + "completion", ) assert m.status is DecisionState.COMPLETED - - - @pytest.mark.unit def test_timer_cancel_event_ordering(): attrs = decision.StartTimerDecisionAttributes(timer_id="t-ord") @@ -114,7 +111,7 @@ def test_timer_cancel_event_ordering(): timer_id="t-ord" ), ), - "initiated" + "initiated", ) m.request_cancel() d1 = m.collect_pending_decisions() @@ -127,7 +124,7 @@ def test_timer_cancel_event_ordering(): timer_id="t-ord" ), ), - "cancel_failed" + "cancel_failed", ) d2 = m.collect_pending_decisions() assert len(d2) == 1 and d2[0].HasField("cancel_timer_decision_attributes") @@ -146,7 +143,7 @@ def test_activity_state_machine_complete_without_cancel(): activity_id="act-1" ), ), - "initiated" + "initiated", ) m.handle_event( history.HistoryEvent( @@ -155,7 +152,7 @@ def test_activity_state_machine_complete_without_cancel(): scheduled_event_id=20 ), ), - "started" + "started", ) m.handle_event( history.HistoryEvent( @@ -164,7 +161,7 @@ def test_activity_state_machine_complete_without_cancel(): scheduled_event_id=20, started_event_id=21 ), ), - "completion" + "completion", ) assert m.status is DecisionState.COMPLETED @@ -205,7 +202,7 @@ def test_activity_state_machine_completed_after_cancel(): activity_id="act-cac" ), ), - "initiated" + "initiated", ) m.handle_event( history.HistoryEvent( @@ -214,7 +211,7 @@ def test_activity_state_machine_completed_after_cancel(): scheduled_event_id=30 ), ), - "started" + "started", ) m.request_cancel() _ = m.collect_pending_decisions() @@ -225,14 +222,11 @@ def test_activity_state_machine_completed_after_cancel(): scheduled_event_id=30, started_event_id=31 ), ), - "completion" + "completion", ) assert m.status is DecisionState.COMPLETED - - - @pytest.mark.unit def test_child_workflow_state_machine_basic(): attrs = decision.StartChildWorkflowExecutionDecisionAttributes( @@ -250,7 +244,7 @@ def test_child_workflow_state_machine_basic(): domain="d1", workflow_id="wf-1" ), ), - "initiated" + "initiated", ) m.handle_event( history.HistoryEvent( @@ -259,7 +253,7 @@ def test_child_workflow_state_machine_basic(): initiated_event_id=40 ), ), - "started" + "started", ) m.handle_event( history.HistoryEvent( @@ -268,7 +262,7 @@ def test_child_workflow_state_machine_basic(): initiated_event_id=40 ), ), - "completion" + "completion", ) assert m.status is DecisionState.COMPLETED @@ -287,7 +281,7 @@ def test_child_workflow_state_machine_cancel_succeed(): domain="d2", workflow_id="wf-2" ), ), - "initiated" + "initiated", ) m.request_cancel() d = m.collect_pending_decisions() @@ -301,14 +295,11 @@ def test_child_workflow_state_machine_cancel_succeed(): initiated_event_id=50 ), ), - "canceled" + "canceled", ) assert m.status is DecisionState.CANCELED_AFTER_INITIATED - - - @pytest.mark.unit @pytest.mark.skip("External cancel failure event handling is not implemented") def test_child_workflow_state_machine_cancel_failed(): diff --git 
a/tests/cadence/_internal/workflow/test_decision_events_iterator.py b/tests/cadence/_internal/workflow/test_decision_events_iterator.py index dd58fbc..94edef9 100644 --- a/tests/cadence/_internal/workflow/test_decision_events_iterator.py +++ b/tests/cadence/_internal/workflow/test_decision_events_iterator.py @@ -19,20 +19,22 @@ DecisionEventsIterator, is_decision_event, is_marker_event, - extract_event_timestamp_millis + extract_event_timestamp_millis, ) -def create_mock_history_event(event_id: int, event_type: str, timestamp_seconds: int = 1000) -> HistoryEvent: +def create_mock_history_event( + event_id: int, event_type: str, timestamp_seconds: int = 1000 +) -> HistoryEvent: """Create a mock history event for testing.""" event = HistoryEvent() event.event_id = event_id - + # Create proper protobuf timestamp timestamp = Timestamp() timestamp.seconds = timestamp_seconds event.event_time.CopyFrom(timestamp) - + # Set the appropriate attribute based on event type if event_type == "decision_task_started": event.decision_task_started_event_attributes.SetInParent() @@ -44,28 +46,30 @@ def create_mock_history_event(event_id: int, event_type: str, timestamp_seconds: event.decision_task_timed_out_event_attributes.SetInParent() elif event_type == "marker_recorded": event.marker_recorded_event_attributes.SetInParent() - + return event -def create_mock_decision_task(events: List[HistoryEvent], next_page_token: bytes = None) -> PollForDecisionTaskResponse: +def create_mock_decision_task( + events: List[HistoryEvent], next_page_token: bytes = None +) -> PollForDecisionTaskResponse: """Create a mock decision task for testing.""" task = PollForDecisionTaskResponse() - + # Mock history history = History() history.events.extend(events) task.history.CopyFrom(history) - + # Mock workflow execution workflow_execution = WorkflowExecution() workflow_execution.workflow_id = "test-workflow" workflow_execution.run_id = "test-run" task.workflow_execution.CopyFrom(workflow_execution) - + if next_page_token: task.next_page_token = next_page_token - + return task @@ -81,80 +85,84 @@ def mock_client(): class TestDecisionEvents: """Test the DecisionEvents class.""" - + def test_decision_events_initialization(self): """Test DecisionEvents initialization.""" decision_events = DecisionEvents() - + assert decision_events.get_events() == [] assert decision_events.get_markers() == [] assert not decision_events.is_replay() assert decision_events.replay_current_time_milliseconds is None assert decision_events.next_decision_event_id is None - + def test_decision_events_with_data(self): """Test DecisionEvents with actual data.""" - events = [create_mock_history_event(1, "decision_task_started"), create_mock_history_event(2, "decision_task_completed")] + events = [ + create_mock_history_event(1, "decision_task_started"), + create_mock_history_event(2, "decision_task_completed"), + ] markers = [create_mock_history_event(3, "marker_recorded")] - + decision_events_obj = DecisionEvents( events=events, markers=markers, replay=True, replay_current_time_milliseconds=123456, - next_decision_event_id=4 + next_decision_event_id=4, ) - + assert decision_events_obj.get_events() == events assert decision_events_obj.get_markers() == markers assert decision_events_obj.is_replay() assert decision_events_obj.replay_current_time_milliseconds == 123456 assert decision_events_obj.next_decision_event_id == 4 - + def test_get_event_by_id(self): """Test retrieving event by ID.""" event1 = create_mock_history_event(1, "decision_task_started") event2 = 
create_mock_history_event(2, "decision_task_completed") - + decision_events = DecisionEvents(events=[event1, event2]) - + assert decision_events.get_event_by_id(1) == event1 assert decision_events.get_event_by_id(2) == event2 assert decision_events.get_event_by_id(999) is None - class TestDecisionEventsIterator: """Test the DecisionEventsIterator class.""" - + @pytest.mark.asyncio async def test_single_decision_iteration(self, mock_client): """Test processing a single decision iteration.""" # Create events for a complete decision iteration events = [ create_mock_history_event(1, "decision_task_started", 1000), - create_mock_history_event(2, "activity_scheduled", 1001), # Some workflow event + create_mock_history_event( + 2, "activity_scheduled", 1001 + ), # Some workflow event create_mock_history_event(3, "marker_recorded", 1002), - create_mock_history_event(4, "decision_task_completed", 1003) + create_mock_history_event(4, "decision_task_completed", 1003), ] - + decision_task = create_mock_decision_task(events) iterator = DecisionEventsIterator(decision_task, mock_client) await iterator._ensure_initialized() - + assert await iterator.has_next_decision_events() - + decision_events = await iterator.next_decision_events() - + assert len(decision_events.get_events()) == 4 assert len(decision_events.get_markers()) == 1 assert decision_events.get_markers()[0].event_id == 3 # In this test scenario with only one decision iteration, replay gets set to false - # when we determine there are no more decision events after this one + # when we determine there are no more decision events after this one # This matches the Java client behavior where the last decision events have replay=false assert not decision_events.is_replay() assert decision_events.replay_current_time_milliseconds == 1000 * 1000 - + @pytest.mark.asyncio async def test_multiple_decision_iterations(self, mock_client): """Test processing multiple decision iterations.""" @@ -165,66 +173,68 @@ async def test_multiple_decision_iterations(self, mock_client): create_mock_history_event(2, "decision_task_completed", 1001), # Second iteration create_mock_history_event(3, "decision_task_started", 1002), - create_mock_history_event(4, "decision_task_completed", 1003) + create_mock_history_event(4, "decision_task_completed", 1003), ] - + decision_task = create_mock_decision_task(events) iterator = DecisionEventsIterator(decision_task, mock_client) await iterator._ensure_initialized() - + # First iteration assert await iterator.has_next_decision_events() first_decision = await iterator.next_decision_events() assert len(first_decision.get_events()) == 2 assert first_decision.get_events()[0].event_id == 1 - + # Second iteration assert await iterator.has_next_decision_events() second_decision = await iterator.next_decision_events() assert len(second_decision.get_events()) == 2 assert second_decision.get_events()[0].event_id == 3 - + # No more iterations assert not await iterator.has_next_decision_events() - + @pytest.mark.asyncio async def test_pagination_support(self, mock_client): """Test that pagination is handled correctly.""" # First page events first_page_events = [ create_mock_history_event(1, "decision_task_started"), - create_mock_history_event(2, "decision_task_completed") + create_mock_history_event(2, "decision_task_completed"), ] - + # Second page events second_page_events = [ create_mock_history_event(3, "decision_task_started"), - create_mock_history_event(4, "decision_task_completed") + create_mock_history_event(4, 
"decision_task_completed"), ] - + # Mock the pagination response pagination_response = GetWorkflowExecutionHistoryResponse() pagination_history = History() pagination_history.events.extend(second_page_events) pagination_response.history.CopyFrom(pagination_history) pagination_response.next_page_token = b"" # No more pages - - mock_client.workflow_stub.GetWorkflowExecutionHistory.return_value = pagination_response - - # Create decision task with next page token + + mock_client.workflow_stub.GetWorkflowExecutionHistory.return_value = ( + pagination_response + ) + + # Create decision task with next page token decision_task = create_mock_decision_task(first_page_events, b"next-page-token") iterator = DecisionEventsIterator(decision_task, mock_client) await iterator._ensure_initialized() - + # Should process both pages iterations_count = 0 while await iterator.has_next_decision_events(): await iterator.next_decision_events() iterations_count += 1 - + assert iterations_count == 2 assert mock_client.workflow_stub.GetWorkflowExecutionHistory.called - + @pytest.mark.asyncio async def test_iterator_protocol(self, mock_client): """Test that DecisionEventsIterator works with Python iterator protocol.""" @@ -232,48 +242,48 @@ async def test_iterator_protocol(self, mock_client): create_mock_history_event(1, "decision_task_started"), create_mock_history_event(2, "decision_task_completed"), create_mock_history_event(3, "decision_task_started"), - create_mock_history_event(4, "decision_task_completed") + create_mock_history_event(4, "decision_task_completed"), ] - + decision_task = create_mock_decision_task(events) iterator = DecisionEventsIterator(decision_task, mock_client) await iterator._ensure_initialized() - + decision_events_list = [] async for decision_events in iterator: decision_events_list.append(decision_events) - - assert len(decision_events_list) == 2 - + assert len(decision_events_list) == 2 class TestUtilityFunctions: """Test utility functions.""" - + def test_is_decision_event(self): """Test is_decision_event utility function.""" decision_event = create_mock_history_event(1, "decision_task_started") - non_decision_event = create_mock_history_event(2, "activity_scheduled") # Random event type - + non_decision_event = create_mock_history_event( + 2, "activity_scheduled" + ) # Random event type + assert is_decision_event(decision_event) assert not is_decision_event(non_decision_event) - + def test_is_marker_event(self): """Test is_marker_event utility function.""" marker_event = create_mock_history_event(1, "marker_recorded") non_marker_event = create_mock_history_event(2, "decision_task_started") - + assert is_marker_event(marker_event) assert not is_marker_event(non_marker_event) - + def test_extract_event_timestamp_millis(self): """Test extract_event_timestamp_millis utility function.""" event = create_mock_history_event(1, "some_event", 1234) - + timestamp_millis = extract_event_timestamp_millis(event) assert timestamp_millis == 1234 * 1000 - + # Test event without timestamp event_no_timestamp = HistoryEvent() assert extract_event_timestamp_millis(event_no_timestamp) is None @@ -281,7 +291,7 @@ def test_extract_event_timestamp_millis(self): class TestIntegrationScenarios: """Test real-world integration scenarios.""" - + @pytest.mark.asyncio async def test_replay_detection(self, mock_client): """Test replay mode detection.""" @@ -291,23 +301,23 @@ async def test_replay_detection(self, mock_client): create_mock_history_event(2, "decision_task_completed"), create_mock_history_event(3, 
"decision_task_started"), # Current decision ] - + decision_task = create_mock_decision_task(events) # Mock the started_event_id to indicate current decision decision_task.started_event_id = 3 - + iterator = DecisionEventsIterator(decision_task, mock_client) await iterator._ensure_initialized() - + # First decision should be replay (but gets set to false when no more events) await iterator.next_decision_events() # Since this test has incomplete events (no completion for the third decision), # the replay logic may behave differently # assert first_decision.is_replay() - + # When we get to current decision, replay should be false # (This would need the completion event to trigger the replay mode change) - + @pytest.mark.asyncio async def test_complex_workflow_scenario(self, mock_client): """Test a complex workflow with multiple event types.""" @@ -319,24 +329,24 @@ async def test_complex_workflow_scenario(self, mock_client): create_mock_history_event(5, "activity_completed"), # Activity completed create_mock_history_event(6, "decision_task_completed"), create_mock_history_event(7, "decision_task_started"), - create_mock_history_event(8, "decision_task_completed") + create_mock_history_event(8, "decision_task_completed"), ] - + decision_task = create_mock_decision_task(events) iterator = DecisionEventsIterator(decision_task, mock_client) - + all_decisions = [] async for decision_events in iterator: all_decisions.append(decision_events) - + assert len(all_decisions) == 2 - + # First decision should have more events including markers first_decision = all_decisions[0] assert len(first_decision.get_events()) == 6 # Events 1-6 assert len(first_decision.get_markers()) == 1 # Event 4 - + # Second decision should be simpler second_decision = all_decisions[1] assert len(second_decision.get_events()) == 2 # Events 7-8 - assert len(second_decision.get_markers()) == 0 \ No newline at end of file + assert len(second_decision.get_markers()) == 0 diff --git a/tests/cadence/_internal/workflow/test_deterministic_event_loop.py b/tests/cadence/_internal/workflow/test_deterministic_event_loop.py index 9d2e7fb..d555148 100644 --- a/tests/cadence/_internal/workflow/test_deterministic_event_loop.py +++ b/tests/cadence/_internal/workflow/test_deterministic_event_loop.py @@ -6,21 +6,25 @@ async def coro_append(results: list, i: int): results.append(i) + async def coro_await(size: int): results = [] for i in range(size): await coro_append(results, i) return results + async def coro_await_future(future: asyncio.Future): return await future + async def coro_await_task(size: int): results = [] for i in range(size): asyncio.create_task(coro_append(results, i)) return results + class TestDeterministicEventLoop: """Test suite for DeterministicEventLoop using table-driven tests.""" @@ -54,8 +58,10 @@ def test_run_until_complete(self): assert self.loop.is_running() is False assert self.loop.is_closed() is False - @pytest.mark.parametrize("result, exception, expected, expected_exception", - [(10000, None, 10000, None), (None, ValueError("test"), None, ValueError)]) + @pytest.mark.parametrize( + "result, exception, expected, expected_exception", + [(10000, None, 10000, None), (None, ValueError("test"), None, ValueError)], + ) def test_create_future(self, result, exception, expected, expected_exception): future = self.loop.create_future() if expected_exception is not None: diff --git a/tests/cadence/_internal/workflow/test_history_event_iterator.py b/tests/cadence/_internal/workflow/test_history_event_iterator.py index 
430f4cf..b1c9b22 100644 --- a/tests/cadence/_internal/workflow/test_history_event_iterator.py +++ b/tests/cadence/_internal/workflow/test_history_event_iterator.py @@ -21,34 +21,31 @@ def mock_client(): @pytest.fixture def mock_workflow_execution(): """Create a mock workflow execution.""" - return WorkflowExecution( - workflow_id="test-workflow-id", - run_id="test-run-id" - ) + return WorkflowExecution(workflow_id="test-workflow-id", run_id="test-run-id") def create_history_event(event_id: int) -> HistoryEvent: return HistoryEvent(event_id=event_id) -async def test_iterate_history_events_single_page_no_next_token(mock_client, mock_workflow_execution): +async def test_iterate_history_events_single_page_no_next_token( + mock_client, mock_workflow_execution +): """Test iterating over a single page of events with no next page token.""" # Create test events - events = [ - create_history_event(1), - create_history_event(2), - create_history_event(3) - ] + events = [create_history_event(1), create_history_event(2), create_history_event(3)] # Create decision task response with events but no next page token decision_task = PollForDecisionTaskResponse( history=History(events=events), next_page_token=b"", # Empty token means no more pages - workflow_execution=mock_workflow_execution + workflow_execution=mock_workflow_execution, ) # Iterate and collect events - result_events = [e async for e in iterate_history_events(decision_task, mock_client)] + result_events = [ + e async for e in iterate_history_events(decision_task, mock_client) + ] # Verify all events were returned assert len(result_events) == 3 @@ -60,17 +57,21 @@ async def test_iterate_history_events_single_page_no_next_token(mock_client, moc mock_client.workflow_stub.GetWorkflowExecutionHistory.assert_not_called() -async def test_iterate_history_events_empty_events(mock_client, mock_workflow_execution): +async def test_iterate_history_events_empty_events( + mock_client, mock_workflow_execution +): """Test iterating over empty events list.""" # Create decision task response with no events decision_task = PollForDecisionTaskResponse( history=History(events=[]), next_page_token=b"", - workflow_execution=mock_workflow_execution + workflow_execution=mock_workflow_execution, ) # Iterate and collect events - result_events = [e async for e in iterate_history_events(decision_task, mock_client)] + result_events = [ + e async for e in iterate_history_events(decision_task, mock_client) + ] # Verify no events were returned assert len(result_events) == 0 @@ -78,43 +79,40 @@ async def test_iterate_history_events_empty_events(mock_client, mock_workflow_ex # Verify no additional API calls were made mock_client.workflow_stub.GetWorkflowExecutionHistory.assert_not_called() -async def test_iterate_history_events_multiple_pages(mock_client, mock_workflow_execution): + +async def test_iterate_history_events_multiple_pages( + mock_client, mock_workflow_execution +): """Test iterating over multiple pages of events.""" # Create decision task response with first page and next page token decision_task = PollForDecisionTaskResponse( - history=History(events=[ - create_history_event(1), - create_history_event(2) - ]), + history=History(events=[create_history_event(1), create_history_event(2)]), next_page_token=b"page2_token", - workflow_execution=mock_workflow_execution + workflow_execution=mock_workflow_execution, ) # Mock the subsequent API calls second_response = GetWorkflowExecutionHistoryResponse( - history=History(events=[ - create_history_event(3), - 
create_history_event(4) - ]), - next_page_token=b"page3_token" + history=History(events=[create_history_event(3), create_history_event(4)]), + next_page_token=b"page3_token", ) third_response = GetWorkflowExecutionHistoryResponse( - history=History(events=[ - create_history_event(5) - ]), - next_page_token=b"" # No more pages + history=History(events=[create_history_event(5)]), + next_page_token=b"", # No more pages ) # Configure mock to return responses in sequence mock_client.workflow_stub.GetWorkflowExecutionHistory.side_effect = [ second_response, - third_response + third_response, ] # Iterate and collect events - result_events = [e async for e in iterate_history_events(decision_task, mock_client)] + result_events = [ + e async for e in iterate_history_events(decision_task, mock_client) + ] # Verify all events from all pages were returned assert len(result_events) == 5 @@ -136,38 +134,42 @@ async def test_iterate_history_events_multiple_pages(mock_client, mock_workflow_ assert first_request.page_size == 1000 # Verify second API call - second_call = mock_client.workflow_stub.GetWorkflowExecutionHistory.call_args_list[1] + second_call = mock_client.workflow_stub.GetWorkflowExecutionHistory.call_args_list[ + 1 + ] second_request = second_call[0][0] assert second_request.domain == "test-domain" assert second_request.workflow_execution == mock_workflow_execution assert second_request.next_page_token == b"page3_token" assert second_request.page_size == 1000 -async def test_iterate_history_events_single_page_with_next_token_then_empty(mock_client, mock_workflow_execution): + +async def test_iterate_history_events_single_page_with_next_token_then_empty( + mock_client, mock_workflow_execution +): """Test case where first page has next token but second page is empty.""" # Create first page of events - first_page_events = [ - create_history_event(1), - create_history_event(2) - ] + first_page_events = [create_history_event(1), create_history_event(2)] # Create decision task response with first page and next page token decision_task = PollForDecisionTaskResponse( history=History(events=first_page_events), next_page_token=b"page2_token", - workflow_execution=mock_workflow_execution + workflow_execution=mock_workflow_execution, ) # Mock the second API call to return empty page second_response = GetWorkflowExecutionHistoryResponse( history=History(events=[]), - next_page_token=b"" # No more pages + next_page_token=b"", # No more pages ) mock_client.workflow_stub.GetWorkflowExecutionHistory.return_value = second_response # Iterate and collect events - result_events = [e async for e in iterate_history_events(decision_task, mock_client)] + result_events = [ + e async for e in iterate_history_events(decision_task, mock_client) + ] # Verify only first page events were returned assert len(result_events) == 2 diff --git a/tests/cadence/_internal/workflow/test_workflow_engine_integration.py b/tests/cadence/_internal/workflow/test_workflow_engine_integration.py index 3805f56..768768e 100644 --- a/tests/cadence/_internal/workflow/test_workflow_engine_integration.py +++ b/tests/cadence/_internal/workflow/test_workflow_engine_integration.py @@ -7,7 +7,11 @@ from unittest.mock import Mock, AsyncMock, patch from cadence.api.v1.service_worker_pb2 import PollForDecisionTaskResponse from cadence.api.v1.common_pb2 import Payload, WorkflowExecution, WorkflowType -from cadence.api.v1.history_pb2 import History, HistoryEvent, WorkflowExecutionStartedEventAttributes +from cadence.api.v1.history_pb2 import ( + History, + 
HistoryEvent, + WorkflowExecutionStartedEventAttributes, +) from cadence._internal.workflow.workflow_engine import WorkflowEngine, DecisionResult from cadence import workflow from cadence.workflow import WorkflowInfo, WorkflowDefinition, WorkflowDefinitionOptions @@ -33,12 +37,13 @@ def workflow_info(self): workflow_type="test_workflow", workflow_domain="test-domain", workflow_id="test-workflow-id", - workflow_run_id="test-run-id" + workflow_run_id="test-run-id", ) @pytest.fixture def mock_workflow_definition(self): """Create a mock workflow definition.""" + class TestWorkflow: @workflow.run async def weird_name(self, input_data): @@ -53,51 +58,62 @@ def workflow_engine(self, mock_client, workflow_info, mock_workflow_definition): return WorkflowEngine( info=workflow_info, client=mock_client, - workflow_definition=mock_workflow_definition + workflow_definition=mock_workflow_definition, ) - def create_mock_decision_task(self, workflow_id="test-workflow", run_id="test-run", workflow_type="test_workflow"): + def create_mock_decision_task( + self, + workflow_id="test-workflow", + run_id="test-run", + workflow_type="test_workflow", + ): """Create a mock decision task with history.""" # Create workflow execution workflow_execution = WorkflowExecution() workflow_execution.workflow_id = workflow_id workflow_execution.run_id = run_id - + # Create workflow type workflow_type_obj = WorkflowType() workflow_type_obj.name = workflow_type - + # Create workflow execution started event started_event = WorkflowExecutionStartedEventAttributes() input_payload = Payload(data=b'"test-input"') started_event.input.CopyFrom(input_payload) - + history_event = HistoryEvent() - history_event.workflow_execution_started_event_attributes.CopyFrom(started_event) - + history_event.workflow_execution_started_event_attributes.CopyFrom( + started_event + ) + # Create history history = History() history.events.append(history_event) - + # Create decision task decision_task = PollForDecisionTaskResponse() decision_task.task_token = b"test-task-token" decision_task.workflow_execution.CopyFrom(workflow_execution) decision_task.workflow_type.CopyFrom(workflow_type_obj) decision_task.history.CopyFrom(history) - + return decision_task @pytest.mark.asyncio async def test_process_decision_success(self, workflow_engine, mock_client): """Test successful decision processing.""" decision_task = self.create_mock_decision_task() - + # Mock the decision manager to return some decisions - with patch.object(workflow_engine._decision_manager, 'collect_pending_decisions', return_value=[Mock()]): + with patch.object( + workflow_engine._decision_manager, + "collect_pending_decisions", + return_value=[Mock()], + ): # Process the decision result = await workflow_engine.process_decision(decision_task) - + # Verify the result assert isinstance(result, DecisionResult) assert len(result.decisions) == 1 @@ -106,28 +122,40 @@ async def test_process_decision_success(self, workflow_engine, mock_client): async def test_process_decision_with_history(self, workflow_engine, mock_client): """Test decision processing with history events.""" decision_task = self.create_mock_decision_task() - + # Mock the decision manager - with patch.object(workflow_engine._decision_manager, 'handle_history_event') as mock_handle: - with patch.object(workflow_engine._decision_manager, 'collect_pending_decisions', return_value=[]): + with patch.object( + workflow_engine._decision_manager, "handle_history_event" + ) as mock_handle: + with patch.object( + 
workflow_engine._decision_manager, + "collect_pending_decisions", + return_value=[], + ): # Process the decision await workflow_engine.process_decision(decision_task) - + # Verify history events were processed mock_handle.assert_called() @pytest.mark.asyncio - async def test_process_decision_workflow_complete(self, workflow_engine, mock_client): + async def test_process_decision_workflow_complete( + self, workflow_engine, mock_client + ): """Test decision processing when workflow is already complete.""" # Mark workflow as complete workflow_engine._is_workflow_complete = True - + decision_task = self.create_mock_decision_task() - - with patch.object(workflow_engine._decision_manager, 'collect_pending_decisions', return_value=[]): + + with patch.object( + workflow_engine._decision_manager, + "collect_pending_decisions", + return_value=[], + ): # Process the decision result = await workflow_engine.process_decision(decision_task) - + # Verify the result assert isinstance(result, DecisionResult) assert len(result.decisions) == 0 @@ -136,12 +164,16 @@ async def test_process_decision_workflow_complete(self, workflow_engine, mock_cl async def test_process_decision_error_handling(self, workflow_engine, mock_client): """Test decision processing error handling.""" decision_task = self.create_mock_decision_task() - + # Mock the decision manager to raise an exception - with patch.object(workflow_engine._decision_manager, 'handle_history_event', side_effect=Exception("Test error")): + with patch.object( + workflow_engine._decision_manager, + "handle_history_event", + side_effect=Exception("Test error"), + ): # Process the decision result = await workflow_engine.process_decision(decision_task) - + # Verify error handling - should return empty decisions assert isinstance(result, DecisionResult) assert len(result.decisions) == 0 @@ -150,66 +182,74 @@ async def test_process_decision_error_handling(self, workflow_engine, mock_clien async def test_extract_workflow_input_success(self, workflow_engine, mock_client): """Test successful workflow input extraction.""" decision_task = self.create_mock_decision_task() - + # Extract workflow input input_data = await workflow_engine._extract_workflow_input(decision_task) - + # Verify the input was extracted assert input_data == "test-input" mock_client.data_converter.from_data.assert_called_once() @pytest.mark.asyncio - async def test_extract_workflow_input_no_history(self, workflow_engine, mock_client): + async def test_extract_workflow_input_no_history( + self, workflow_engine, mock_client + ): """Test workflow input extraction with no history.""" decision_task = PollForDecisionTaskResponse() decision_task.task_token = b"test-task-token" # No history set - + # Extract workflow input input_data = await workflow_engine._extract_workflow_input(decision_task) - + # Verify no input was extracted assert input_data is None @pytest.mark.asyncio - async def test_extract_workflow_input_no_started_event(self, workflow_engine, mock_client): + async def test_extract_workflow_input_no_started_event( + self, workflow_engine, mock_client + ): """Test workflow input extraction with no WorkflowExecutionStarted event.""" # Create a decision task with no started event decision_task = PollForDecisionTaskResponse() decision_task.task_token = b"test-task-token" - + # Create workflow execution workflow_execution = WorkflowExecution() workflow_execution.workflow_id = "test-workflow" workflow_execution.run_id = "test-run" decision_task.workflow_execution.CopyFrom(workflow_execution) - + # 
Create workflow type workflow_type_obj = WorkflowType() workflow_type_obj.name = "test_workflow" decision_task.workflow_type.CopyFrom(workflow_type_obj) - + # Create history with no events history = History() decision_task.history.CopyFrom(history) - + # Extract workflow input input_data = await workflow_engine._extract_workflow_input(decision_task) - + # Verify no input was extracted assert input_data is None @pytest.mark.asyncio - async def test_extract_workflow_input_deserialization_error(self, workflow_engine, mock_client): + async def test_extract_workflow_input_deserialization_error( + self, workflow_engine, mock_client + ): """Test workflow input extraction with deserialization error.""" decision_task = self.create_mock_decision_task() - + # Mock data converter to raise an exception - mock_client.data_converter.from_data = AsyncMock(side_effect=Exception("Deserialization error")) - + mock_client.data_converter.from_data = AsyncMock( + side_effect=Exception("Deserialization error") + ) + # Extract workflow input input_data = await workflow_engine._extract_workflow_input(decision_task) - + # Verify no input was extracted due to error assert input_data is None @@ -219,10 +259,14 @@ async def test_execute_workflow_function_sync(self, workflow_engine): input_data = "test-input" # Get the workflow function from the instance - workflow_func = workflow_engine._workflow_definition.get_run_method(workflow_engine._workflow_instance) + workflow_func = workflow_engine._workflow_definition.get_run_method( + workflow_engine._workflow_instance + ) # Execute the workflow function - result = await workflow_engine._execute_workflow_function_once(workflow_func, input_data) + result = await workflow_engine._execute_workflow_function_once( + workflow_func, input_data + ) # Verify the result assert result == "processed: test-input" @@ -230,14 +274,17 @@ async def test_execute_workflow_function_sync(self, workflow_engine): @pytest.mark.asyncio async def test_execute_workflow_function_async(self, workflow_engine): """Test asynchronous workflow function execution.""" + async def async_workflow_func(input_data): return f"async-processed: {input_data}" - + input_data = "test-input" - + # Execute the async workflow function - result = await workflow_engine._execute_workflow_function_once(async_workflow_func, input_data) - + result = await workflow_engine._execute_workflow_function_once( + async_workflow_func, input_data + ) + # Verify the result assert result == "async-processed: test-input" @@ -245,12 +292,14 @@ async def async_workflow_func(input_data): async def test_execute_workflow_function_none(self, workflow_engine): """Test workflow function execution with None function.""" input_data = "test-input" - + # Execute with None workflow function - should raise TypeError with pytest.raises(TypeError, match="'NoneType' object is not callable"): await workflow_engine._execute_workflow_function_once(None, input_data) - def test_workflow_engine_initialization(self, workflow_engine, workflow_info, mock_client, mock_workflow_definition): + def test_workflow_engine_initialization( + self, workflow_engine, workflow_info, mock_client, mock_workflow_definition + ): """Test WorkflowEngine initialization.""" assert workflow_engine._context is not None assert workflow_engine._workflow_definition == mock_workflow_definition @@ -259,26 +308,30 @@ def test_workflow_engine_initialization(self, workflow_engine, workflow_info, mo assert workflow_engine._is_workflow_complete is False @pytest.mark.asyncio - async def 
test_workflow_engine_without_workflow_definition(self, mock_client, workflow_info): + async def test_workflow_engine_without_workflow_definition( + self, mock_client, workflow_info + ): """Test WorkflowEngine without workflow definition.""" engine = WorkflowEngine( - info=workflow_info, - client=mock_client, - workflow_definition=None + info=workflow_info, client=mock_client, workflow_definition=None ) - + decision_task = self.create_mock_decision_task() - - with patch.object(engine._decision_manager, 'collect_pending_decisions', return_value=[]): + + with patch.object( + engine._decision_manager, "collect_pending_decisions", return_value=[] + ): # Process the decision result = await engine.process_decision(decision_task) - + # Verify the result assert isinstance(result, DecisionResult) assert len(result.decisions) == 0 @pytest.mark.asyncio - async def test_workflow_engine_workflow_completion(self, workflow_engine, mock_client): + async def test_workflow_engine_workflow_completion( + self, workflow_engine, mock_client + ): """Test workflow completion detection.""" decision_task = self.create_mock_decision_task() @@ -289,16 +342,22 @@ async def run(self, input_data): return "workflow-completed" workflow_opts = WorkflowDefinitionOptions(name="completing_workflow") - completing_definition = WorkflowDefinition.wrap(CompletingWorkflow, workflow_opts) + completing_definition = WorkflowDefinition.wrap( + CompletingWorkflow, workflow_opts + ) # Replace the workflow definition and instance workflow_engine._workflow_definition = completing_definition workflow_engine._workflow_instance = completing_definition.cls() - - with patch.object(workflow_engine._decision_manager, 'collect_pending_decisions', return_value=[]): + + with patch.object( + workflow_engine._decision_manager, + "collect_pending_decisions", + return_value=[], + ): # Process the decision await workflow_engine.process_decision(decision_task) - + # Verify workflow is marked as complete assert workflow_engine._is_workflow_complete is True @@ -308,18 +367,26 @@ def test_close_event_loop(self, workflow_engine): workflow_engine._close_event_loop() @pytest.mark.asyncio - async def test_process_decision_with_query_results(self, workflow_engine, mock_client): + async def test_process_decision_with_query_results( + self, workflow_engine, mock_client + ): """Test decision processing with query results.""" decision_task = self.create_mock_decision_task() - + # Mock the decision manager to return decisions with query results mock_decisions = [Mock()] - - with patch.object(workflow_engine._decision_manager, 'collect_pending_decisions', return_value=mock_decisions): + + with patch.object( + workflow_engine._decision_manager, + "collect_pending_decisions", + return_value=mock_decisions, + ): # Process the decision result = await workflow_engine.process_decision(decision_task) - + # Verify the result assert isinstance(result, DecisionResult) assert len(result.decisions) == 1 - # Not set in this test + + +# Not set in this test diff --git a/tests/cadence/common_activities.py b/tests/cadence/common_activities.py index be78c62..a183e42 100644 --- a/tests/cadence/common_activities.py +++ b/tests/cadence/common_activities.py @@ -7,24 +7,28 @@ def simple_fn() -> None: pass + @activity.defn def no_parens() -> None: pass + @activity.defn() def echo(incoming: str) -> str: return incoming + @activity.defn(name="renamed") def renamed_fn() -> None: pass + @activity.defn() async def async_fn() -> None: pass -class Activities: +class Activities: @activity.defn() def 
echo_sync(self, incoming: str) -> str: return incoming @@ -33,10 +37,11 @@ def echo_sync(self, incoming: str) -> str: async def echo_async(self, incoming: str) -> str: return incoming + class ActivityInterface: @activity.defn() - def do_something(self) -> str: - ... + def do_something(self) -> str: ... + @dataclass class ActivityImpl(ActivityInterface): @@ -45,7 +50,8 @@ class ActivityImpl(ActivityInterface): def do_something(self) -> str: return self.result + class InvalidImpl(ActivityInterface): @activity.defn(name="something else entirely") def do_something(self) -> str: - return "hehe" \ No newline at end of file + return "hehe" diff --git a/tests/cadence/data_converter_test.py b/tests/cadence/data_converter_test.py index 068f3a5..91a2ff4 100644 --- a/tests/cadence/data_converter_test.py +++ b/tests/cadence/data_converter_test.py @@ -7,103 +7,98 @@ from cadence.data_converter import DefaultDataConverter from msgspec import json + @dataclasses.dataclass class _TestDataClass: foo: str = "foo" bar: int = -1 - baz: Optional['_TestDataClass'] = None + baz: Optional["_TestDataClass"] = None + @pytest.mark.parametrize( "json,types,expected", [ - pytest.param( - '"Hello world"', [str], ["Hello world"], id="happy path" - ), + pytest.param('"Hello world"', [str], ["Hello world"], id="happy path"), pytest.param( '"Hello" "world"', [str, str], ["Hello", "world"], id="space delimited" ), + pytest.param("1", [int, int], [1, 0], id="ints"), + pytest.param("1.5", [float, float], [1.5, 0.0], id="floats"), + pytest.param("true", [bool, bool], [True, False], id="bools"), pytest.param( - "1", [int, int], [1, 0], id="ints" - ), - pytest.param( - "1.5", [float, float], [1.5, 0.0], id="floats" - ), - pytest.param( - "true", [bool, bool], [True, False], id="bools" + '{"foo": "hello world", "bar": 42, "baz": {"bar": 43}}', + [_TestDataClass, _TestDataClass], + [_TestDataClass("hello world", 42, _TestDataClass(bar=43)), None], + id="data classes", ), pytest.param( - '{"foo": "hello world", "bar": 42, "baz": {"bar": 43}}', [_TestDataClass, _TestDataClass], [_TestDataClass("hello world", 42, _TestDataClass(bar=43)), None], id="data classes" + '{"foo": "hello world"}', + [dict, dict], + [{"foo": "hello world"}, None], + id="dicts", ), pytest.param( - '{"foo": "hello world"}', [dict, dict], [{"foo": "hello world"}, None], id="dicts" - ), - pytest.param( - '{"foo": 52}', [dict[str, int], dict], [{"foo": 52}, None], id="generic dicts" + '{"foo": 52}', + [dict[str, int], dict], + [{"foo": 52}, None], + id="generic dicts", ), pytest.param( '["hello"]', [list[str], list[str]], [["hello"], None], id="lists" ), - pytest.param( - '["hello"]', [set[str], set[str]], [{"hello"}, None], id="sets" - ), + pytest.param('["hello"]', [set[str], set[str]], [{"hello"}, None], id="sets"), pytest.param( '["hello", "world"]', [list[str]], [["hello", "world"]], id="list" ), pytest.param( - '{"foo": "bar"} {"bar": 100} ["hello"] "world"', [_TestDataClass, _TestDataClass, list[str], str], - [_TestDataClass(foo="bar"), _TestDataClass(bar=100), ["hello"], "world"], id="space delimited mix" + '{"foo": "bar"} {"bar": 100} ["hello"] "world"', + [_TestDataClass, _TestDataClass, list[str], str], + [_TestDataClass(foo="bar"), _TestDataClass(bar=100), ["hello"], "world"], + id="space delimited mix", ), + pytest.param("", [], [], id="no input expected"), + pytest.param("", [str], [None], id="no input unexpected"), pytest.param( - "", [], [], id="no input expected" + '"hello world" {"foo":"bar"} 7', + [None, None, None], + ["hello world", {"foo": 
"bar"}, 7], + id="no type hints", ), pytest.param( - "", [str], [None], id="no input unexpected" + '"hello" "world" "goodbye"', + [str, str], + ["hello", "world"], + id="extra content", ), - pytest.param( - '"hello world" {"foo":"bar"} 7', [None, None, None], ["hello world", {"foo":"bar"}, 7], id="no type hints" - ), - pytest.param( - '"hello" "world" "goodbye"', [str, str], ["hello", "world"], - id="extra content" - ), - ] + ], ) -async def test_data_converter_from_data(json: str, types: list[Type], expected: list[Any]) -> None: +async def test_data_converter_from_data( + json: str, types: list[Type], expected: list[Any] +) -> None: converter = DefaultDataConverter() actual = await converter.from_data(Payload(data=json.encode()), types) assert expected == actual + @pytest.mark.parametrize( "values,expected", [ + pytest.param(["hello world"], '"hello world"', id="happy path"), + pytest.param(["hello", "world"], '"hello" "world"', id="multiple values"), + pytest.param([[["hello"]], ["world"]], '[["hello"]] ["world"]', id="lists"), + pytest.param([1, 2, 10], "1 2 10", id="numeric values"), + pytest.param([True, False], "true false", id="bool values"), pytest.param( - ["hello world"], '"hello world"', id="happy path" - ), - pytest.param( - ["hello", "world"], '"hello" "world"', id="multiple values" - ), - pytest.param( - [[["hello"]], ["world"]], '[["hello"]] ["world"]', id="lists" - ), - pytest.param( - [1, 2, 10], '1 2 10', id="numeric values" - ), - pytest.param( - [True, False], 'true false', id="bool values" - ), - pytest.param( - [{'foo': 'foo', 'bar': 20}], '{"bar":20,"foo":"foo"}', id="dict values" - ), - pytest.param( - [{'foo', 'bar'}], '["bar","foo"]', id="set values" + [{"foo": "foo", "bar": 20}], '{"bar":20,"foo":"foo"}', id="dict values" ), + pytest.param([{"foo", "bar"}], '["bar","foo"]', id="set values"), pytest.param( [_TestDataClass()], '{"foo":"foo","bar":-1,"baz":null}', id="data classes" ), - ] + ], ) async def test_data_converter_to_data(values: list[Any], expected: str) -> None: converter = DefaultDataConverter() - converter._encoder = json.Encoder(order='deterministic') + converter._encoder = json.Encoder(order="deterministic") actual = await converter.to_data(values) - assert actual.data.decode() == expected \ No newline at end of file + assert actual.data.decode() == expected diff --git a/tests/cadence/metrics/test_metrics.py b/tests/cadence/metrics/test_metrics.py index fdf4bd6..1b29cc9 100644 --- a/tests/cadence/metrics/test_metrics.py +++ b/tests/cadence/metrics/test_metrics.py @@ -34,10 +34,7 @@ def test_mock_emitter(self): # Test gauge mock_emitter.gauge("test_gauge", 100.0, {"env": "test"}) - mock_emitter.gauge.assert_called_once_with( - "test_gauge", 100.0, {"env": "test"} - ) - + mock_emitter.gauge.assert_called_once_with("test_gauge", 100.0, {"env": "test"}) # Test histogram mock_emitter.histogram("test_histogram", 2.5, {"env": "prod"}) @@ -46,8 +43,6 @@ def test_mock_emitter(self): ) - - class TestMetricType: """Test cases for MetricType enum.""" diff --git a/tests/cadence/metrics/test_prometheus.py b/tests/cadence/metrics/test_prometheus.py index ad561e1..1adf918 100644 --- a/tests/cadence/metrics/test_prometheus.py +++ b/tests/cadence/metrics/test_prometheus.py @@ -20,12 +20,9 @@ def test_default_config(self): def test_custom_config(self): """Test custom configuration values.""" from prometheus_client import CollectorRegistry - + registry = CollectorRegistry() - config = PrometheusConfig( - default_labels={"env": "test"}, - registry=registry - ) + config 
= PrometheusConfig(default_labels={"env": "test"}, registry=registry) assert config.default_labels == {"env": "test"} assert config.registry is registry @@ -41,62 +38,58 @@ def test_init_with_default_config(self): def test_init_with_custom_config(self): """Test initialization with custom config.""" from prometheus_client import CollectorRegistry - + registry = CollectorRegistry() - config = PrometheusConfig( - default_labels={"service": "test"}, - registry=registry - ) + config = PrometheusConfig(default_labels={"service": "test"}, registry=registry) metrics = PrometheusMetrics(config) assert metrics.registry is registry - @patch('cadence.metrics.prometheus.Counter') + @patch("cadence.metrics.prometheus.Counter") def test_counter_metric(self, mock_counter_class): """Test counter metric creation and usage.""" mock_counter = Mock() mock_counter_class.return_value = mock_counter - + metrics = PrometheusMetrics() metrics.counter("test_counter", 5, {"label": "value"}) - + # Verify counter was created mock_counter_class.assert_called_once() mock_counter.labels.assert_called_once_with(label="value") mock_counter.labels.return_value.inc.assert_called_once_with(5) - @patch('cadence.metrics.prometheus.Gauge') + @patch("cadence.metrics.prometheus.Gauge") def test_gauge_metric(self, mock_gauge_class): """Test gauge metric creation and usage.""" mock_gauge = Mock() mock_gauge_class.return_value = mock_gauge - + metrics = PrometheusMetrics() metrics.gauge("test_gauge", 42.5, {"env": "prod"}) - + # Verify gauge was created mock_gauge_class.assert_called_once() mock_gauge.labels.assert_called_once_with(env="prod") mock_gauge.labels.return_value.set.assert_called_once_with(42.5) - @patch('cadence.metrics.prometheus.Histogram') + @patch("cadence.metrics.prometheus.Histogram") def test_histogram_metric(self, mock_histogram_class): """Test histogram metric creation and usage.""" mock_histogram = Mock() mock_histogram_class.return_value = mock_histogram - + metrics = PrometheusMetrics() metrics.histogram("test_histogram", 1.5, {"type": "latency"}) - + # Verify histogram was created mock_histogram_class.assert_called_once() mock_histogram.labels.assert_called_once_with(type="latency") mock_histogram.labels.return_value.observe.assert_called_once_with(1.5) - def test_metric_name_generation(self): """Test metric name generation.""" metrics = PrometheusMetrics() - + metric_name = metrics._get_metric_name("test_metric") assert metric_name == "test_metric" @@ -106,42 +99,48 @@ def test_label_merging(self): default_labels={"service": "cadence", "version": "1.0"} ) metrics = PrometheusMetrics(config) - + # Test merging with provided labels merged = metrics._merge_labels({"operation": "start"}) expected = {"service": "cadence", "version": "1.0", "operation": "start"} assert merged == expected - + # Test merging with None labels merged_none = metrics._merge_labels(None) assert merged_none == {"service": "cadence", "version": "1.0"} - @patch('cadence.metrics.prometheus.generate_latest') + @patch("cadence.metrics.prometheus.generate_latest") def test_get_metrics_text(self, mock_generate_latest): """Test getting metrics in text format.""" mock_generate_latest.return_value = b"# HELP test_metric Test metric\n# TYPE test_metric counter\ntest_metric 1.0\n" - + metrics = PrometheusMetrics() result = metrics.get_metrics_text() - - assert result == "# HELP test_metric Test metric\n# TYPE test_metric counter\ntest_metric 1.0\n" + + assert ( + result + == "# HELP test_metric Test metric\n# TYPE test_metric counter\ntest_metric 
1.0\n" + ) mock_generate_latest.assert_called_once_with(metrics.registry) def test_error_handling_in_counter(self): """Test error handling in counter method.""" metrics = PrometheusMetrics() - + # This should not raise an exception - with patch.object(metrics, '_get_or_create_counter', side_effect=Exception("Test error")): + with patch.object( + metrics, "_get_or_create_counter", side_effect=Exception("Test error") + ): metrics.counter("test_counter", 1) # Should not raise, just log error def test_error_handling_in_gauge(self): """Test error handling in gauge method.""" metrics = PrometheusMetrics() - + # This should not raise an exception - with patch.object(metrics, '_get_or_create_gauge', side_effect=Exception("Test error")): + with patch.object( + metrics, "_get_or_create_gauge", side_effect=Exception("Test error") + ): metrics.gauge("test_gauge", 1.0) # Should not raise, just log error - diff --git a/tests/cadence/worker/test_base_task_handler.py b/tests/cadence/worker/test_base_task_handler.py index d5d48a6..6d1077c 100644 --- a/tests/cadence/worker/test_base_task_handler.py +++ b/tests/cadence/worker/test_base_task_handler.py @@ -11,21 +11,21 @@ class ConcreteTaskHandler(BaseTaskHandler[str]): """Concrete implementation of BaseTaskHandler for testing.""" - + def __init__(self, client, task_list: str, identity: str, **options): super().__init__(client, task_list, identity, **options) self._handle_task_implementation_called = False self._handle_task_failure_called = False self._last_task: str = "" self._last_error: Exception | None = None - + async def _handle_task_implementation(self, task: str) -> None: """Test implementation of task handling.""" self._handle_task_implementation_called = True self._last_task = task if task == "raise_error": raise ValueError("Test error") - + async def handle_task_failure(self, task: str, error: Exception) -> None: """Test implementation of task failure handling.""" self._handle_task_failure_called = True @@ -35,7 +35,7 @@ async def handle_task_failure(self, task: str, error: Exception) -> None: class TestBaseTaskHandler: """Test cases for BaseTaskHandler.""" - + def test_initialization(self): """Test BaseTaskHandler initialization.""" client = Mock() @@ -44,81 +44,79 @@ def test_initialization(self): task_list="test_task_list", identity="test_identity", option1="value1", - option2="value2" + option2="value2", ) - + assert handler._client == client assert handler._task_list == "test_task_list" assert handler._identity == "test_identity" assert handler._options == {"option1": "value1", "option2": "value2"} - + @pytest.mark.asyncio async def test_handle_task_success(self): """Test successful task handling.""" client = Mock() handler = ConcreteTaskHandler(client, "test_task_list", "test_identity") - + await handler.handle_task("test_task") - + # Verify implementation was called assert handler._handle_task_implementation_called assert not handler._handle_task_failure_called assert handler._last_task == "test_task" assert handler._last_error is None - + @pytest.mark.asyncio async def test_handle_task_failure(self): """Test task handling with error.""" client = Mock() handler = ConcreteTaskHandler(client, "test_task_list", "test_identity") - + await handler.handle_task("raise_error") - + # Verify error handling was called assert handler._handle_task_implementation_called assert handler._handle_task_failure_called assert handler._last_task == "raise_error" assert isinstance(handler._last_error, ValueError) assert str(handler._last_error) == "Test error" - - 
+ @pytest.mark.asyncio async def test_abstract_methods_not_implemented(self): """Test that abstract methods raise NotImplementedError when not implemented.""" client = Mock() - + class IncompleteHandler(BaseTaskHandler[str]): async def _handle_task_implementation(self, task: str) -> None: raise NotImplementedError() - + async def handle_task_failure(self, task: str, error: Exception) -> None: raise NotImplementedError() - + handler = IncompleteHandler(client, "test_task_list", "test_identity") - + with pytest.raises(NotImplementedError): await handler._handle_task_implementation("test") - + with pytest.raises(NotImplementedError): await handler.handle_task_failure("test", Exception("test")) - - + @pytest.mark.asyncio async def test_generic_type_parameter(self): """Test that the generic type parameter works correctly.""" client = Mock() - + class IntHandler(BaseTaskHandler[int]): async def _handle_task_implementation(self, task: int) -> None: pass - + async def handle_task_failure(self, task: int, error: Exception) -> None: pass - + handler = IntHandler(client, "test_task_list", "test_identity") - + # Should accept int tasks await handler.handle_task(42) - + # Type checker should catch type mismatches (this is more of a static analysis test) # In runtime, Python won't enforce the type, but the type hints are there for static analysis diff --git a/tests/cadence/worker/test_decision_task_handler.py b/tests/cadence/worker/test_decision_task_handler.py index da2e79a..b7c0957 100644 --- a/tests/cadence/worker/test_decision_task_handler.py +++ b/tests/cadence/worker/test_decision_task_handler.py @@ -9,7 +9,7 @@ from cadence.api.v1.common_pb2 import Payload from cadence.api.v1.service_worker_pb2 import ( PollForDecisionTaskResponse, - RespondDecisionTaskCompletedRequest + RespondDecisionTaskCompletedRequest, ) from cadence.api.v1.workflow_pb2 import DecisionTaskFailedCause from cadence.api.v1.decision_pb2 import Decision @@ -23,7 +23,7 @@ class TestDecisionTaskHandler: """Test cases for DecisionTaskHandler.""" - + @pytest.fixture def mock_client(self): """Create a mock client.""" @@ -33,13 +33,13 @@ def mock_client(self): client.worker_stub.RespondDecisionTaskFailed = AsyncMock() type(client).domain = PropertyMock(return_value="test_domain") return client - + @pytest.fixture def mock_registry(self): """Create a mock registry.""" registry = Mock(spec=Registry) return registry - + @pytest.fixture def handler(self, mock_client, mock_registry): """Create a DecisionTaskHandler instance.""" @@ -47,9 +47,9 @@ def handler(self, mock_client, mock_registry): client=mock_client, task_list="test_task_list", registry=mock_registry, - identity="test_identity" + identity="test_identity", ) - + @pytest.fixture def sample_decision_task(self): """Create a sample decision task.""" @@ -64,7 +64,7 @@ def sample_decision_task(self): task.started_event_id = 1 task.attempt = 1 return task - + def test_initialization(self, mock_client, mock_registry): """Test DecisionTaskHandler initialization.""" handler = DecisionTaskHandler( @@ -72,18 +72,21 @@ def test_initialization(self, mock_client, mock_registry): task_list="test_task_list", registry=mock_registry, identity="test_identity", - option1="value1" + option1="value1", ) - + assert handler._client == mock_client assert handler._task_list == "test_task_list" assert handler._identity == "test_identity" assert handler._registry == mock_registry assert handler._options == {"option1": "value1"} - + @pytest.mark.asyncio - async def test_handle_task_implementation_success(self, 
handler, sample_decision_task, mock_registry): + async def test_handle_task_implementation_success( + self, handler, sample_decision_task, mock_registry + ): """Test successful decision task handling.""" + # Create actual workflow definition class MockWorkflow: @workflow.run @@ -93,7 +96,7 @@ async def run(self): workflow_opts = WorkflowDefinitionOptions(name="test_workflow") workflow_definition = WorkflowDefinition.wrap(MockWorkflow, workflow_opts) mock_registry.get_workflow.return_value = workflow_definition - + # Mock workflow engine mock_engine = Mock(spec=WorkflowEngine) mock_engine._is_workflow_complete = False # Add missing attribute @@ -101,19 +104,22 @@ async def run(self): mock_decision_result = Mock(spec=DecisionResult) mock_decision_result.decisions = [Decision()] mock_engine.process_decision = AsyncMock(return_value=mock_decision_result) - - with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine): + + with patch( + "cadence.worker._decision_task_handler.WorkflowEngine", + return_value=mock_engine, + ): await handler._handle_task_implementation(sample_decision_task) - + # Verify registry was called mock_registry.get_workflow.assert_called_once_with("TestWorkflow") - + # Verify workflow engine was created and used mock_engine.process_decision.assert_called_once_with(sample_decision_task) - + # Verify response was sent handler._client.worker_stub.RespondDecisionTaskCompleted.assert_called_once() - + @pytest.mark.asyncio async def test_handle_task_implementation_missing_workflow_execution(self, handler): """Test decision task handling with missing workflow execution.""" @@ -122,10 +128,10 @@ async def test_handle_task_implementation_missing_workflow_execution(self, handl task.workflow_execution = None task.workflow_type = Mock() task.workflow_type.name = "TestWorkflow" - + with pytest.raises(ValueError, match="Missing workflow execution or type"): await handler._handle_task_implementation(task) - + @pytest.mark.asyncio async def test_handle_task_implementation_missing_workflow_type(self, handler): """Test decision task handling with missing workflow type.""" @@ -135,21 +141,26 @@ async def test_handle_task_implementation_missing_workflow_type(self, handler): task.workflow_execution.workflow_id = "test_workflow_id" task.workflow_execution.run_id = "test_run_id" task.workflow_type = None - + with pytest.raises(ValueError, match="Missing workflow execution or type"): await handler._handle_task_implementation(task) - + @pytest.mark.asyncio - async def test_handle_task_implementation_workflow_not_found(self, handler, sample_decision_task, mock_registry): + async def test_handle_task_implementation_workflow_not_found( + self, handler, sample_decision_task, mock_registry + ): """Test decision task handling when workflow is not found in registry.""" mock_registry.get_workflow.side_effect = KeyError("Workflow not found") - + with pytest.raises(KeyError, match="Workflow type 'TestWorkflow' not found"): await handler._handle_task_implementation(sample_decision_task) - + @pytest.mark.asyncio - async def test_handle_task_implementation_caches_engines(self, handler, sample_decision_task, mock_registry): + async def test_handle_task_implementation_caches_engines( + self, handler, sample_decision_task, mock_registry + ): """Test that decision task handler caches workflow engines for same workflow execution.""" + # Create actual workflow definition class MockWorkflow: @workflow.run @@ -159,33 +170,39 @@ async def run(self): workflow_opts = 
WorkflowDefinitionOptions(name="test_workflow") workflow_definition = WorkflowDefinition.wrap(MockWorkflow, workflow_opts) mock_registry.get_workflow.return_value = workflow_definition - + # Mock workflow engine mock_engine = Mock(spec=WorkflowEngine) mock_engine._is_workflow_complete = False # Add missing attribute mock_decision_result = Mock(spec=DecisionResult) mock_decision_result.decisions = [] mock_engine.process_decision = AsyncMock(return_value=mock_decision_result) - - with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine) as mock_engine_class: + + with patch( + "cadence.worker._decision_task_handler.WorkflowEngine", + return_value=mock_engine, + ) as mock_engine_class: # First call - should create new engine await handler._handle_task_implementation(sample_decision_task) - + # Second call with same workflow_id and run_id - should reuse cached engine await handler._handle_task_implementation(sample_decision_task) - + # Registry should be called for each task (to get workflow function) assert mock_registry.get_workflow.call_count == 2 - + # Engine should be created only once (cached for second call) assert mock_engine_class.call_count == 1 - + # But process_decision should be called twice assert mock_engine.process_decision.call_count == 2 - + @pytest.mark.asyncio - async def test_handle_task_implementation_different_executions_get_separate_engines(self, handler, mock_registry): + async def test_handle_task_implementation_different_executions_get_separate_engines( + self, handler, mock_registry + ): """Test that different workflow executions get separate engines.""" + # Create actual workflow definition class MockWorkflow: @workflow.run @@ -195,7 +212,7 @@ async def run(self): workflow_opts = WorkflowDefinitionOptions(name="test_workflow") workflow_definition = WorkflowDefinition.wrap(MockWorkflow, workflow_opts) mock_registry.get_workflow.return_value = workflow_definition - + # Create two different decision tasks task1 = Mock(spec=PollForDecisionTaskResponse) task1.task_token = b"test_task_token_1" @@ -206,143 +223,191 @@ async def run(self): task1.workflow_type.name = "TestWorkflow" task1.started_event_id = 1 task1.attempt = 1 - + task2 = Mock(spec=PollForDecisionTaskResponse) task2.task_token = b"test_task_token_2" task2.workflow_execution = Mock() task2.workflow_execution.workflow_id = "workflow_2" # Different workflow - task2.workflow_execution.run_id = "run_2" # Different run + task2.workflow_execution.run_id = "run_2" # Different run task2.workflow_type = Mock() task2.workflow_type.name = "TestWorkflow" task2.started_event_id = 2 task2.attempt = 1 - + # Mock workflow engine mock_engine = Mock(spec=WorkflowEngine) mock_engine._is_workflow_complete = False # Add missing attribute mock_decision_result = Mock(spec=DecisionResult) mock_decision_result.decisions = [] mock_engine.process_decision = AsyncMock(return_value=mock_decision_result) - - with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine) as mock_engine_class: + + with patch( + "cadence.worker._decision_task_handler.WorkflowEngine", + return_value=mock_engine, + ) as mock_engine_class: # Process different workflow executions await handler._handle_task_implementation(task1) await handler._handle_task_implementation(task2) - + # Registry should be called for each task assert mock_registry.get_workflow.call_count == 2 - + # Engine should be created twice (different executions) assert mock_engine_class.call_count == 2 - + # Process_decision should 
be called twice assert mock_engine.process_decision.call_count == 2 - + @pytest.mark.asyncio async def test_handle_task_failure_keyerror(self, handler, sample_decision_task): """Test task failure handling for KeyError.""" error = KeyError("Workflow not found") - + await handler.handle_task_failure(sample_decision_task, error) - + # Verify the correct failure cause was used - call_args = handler._client.worker_stub.RespondDecisionTaskFailed.call_args[0][0] - assert call_args.cause == DecisionTaskFailedCause.DECISION_TASK_FAILED_CAUSE_WORKFLOW_WORKER_UNHANDLED_FAILURE + call_args = handler._client.worker_stub.RespondDecisionTaskFailed.call_args[0][ + 0 + ] + assert ( + call_args.cause + == DecisionTaskFailedCause.DECISION_TASK_FAILED_CAUSE_WORKFLOW_WORKER_UNHANDLED_FAILURE + ) assert call_args.task_token == sample_decision_task.task_token assert call_args.identity == handler._identity - + @pytest.mark.asyncio async def test_handle_task_failure_valueerror(self, handler, sample_decision_task): """Test task failure handling for ValueError.""" error = ValueError("Invalid workflow attributes") - + await handler.handle_task_failure(sample_decision_task, error) - + # Verify the correct failure cause was used - call_args = handler._client.worker_stub.RespondDecisionTaskFailed.call_args[0][0] - assert call_args.cause == DecisionTaskFailedCause.DECISION_TASK_FAILED_CAUSE_BAD_SCHEDULE_ACTIVITY_ATTRIBUTES + call_args = handler._client.worker_stub.RespondDecisionTaskFailed.call_args[0][ + 0 + ] + assert ( + call_args.cause + == DecisionTaskFailedCause.DECISION_TASK_FAILED_CAUSE_BAD_SCHEDULE_ACTIVITY_ATTRIBUTES + ) assert call_args.task_token == sample_decision_task.task_token assert call_args.identity == handler._identity - + @pytest.mark.asyncio - async def test_handle_task_failure_generic_error(self, handler, sample_decision_task): + async def test_handle_task_failure_generic_error( + self, handler, sample_decision_task + ): """Test task failure handling for generic error.""" error = RuntimeError("Generic error") - + await handler.handle_task_failure(sample_decision_task, error) - + # Verify the default failure cause was used - call_args = handler._client.worker_stub.RespondDecisionTaskFailed.call_args[0][0] - assert call_args.cause == DecisionTaskFailedCause.DECISION_TASK_FAILED_CAUSE_UNHANDLED_DECISION + call_args = handler._client.worker_stub.RespondDecisionTaskFailed.call_args[0][ + 0 + ] + assert ( + call_args.cause + == DecisionTaskFailedCause.DECISION_TASK_FAILED_CAUSE_UNHANDLED_DECISION + ) assert call_args.task_token == sample_decision_task.task_token assert call_args.identity == handler._identity - + @pytest.mark.asyncio - async def test_handle_task_failure_with_error_details(self, handler, sample_decision_task): + async def test_handle_task_failure_with_error_details( + self, handler, sample_decision_task + ): """Test task failure handling includes error details.""" error = ValueError("Test error message") - + await handler.handle_task_failure(sample_decision_task, error) - - call_args = handler._client.worker_stub.RespondDecisionTaskFailed.call_args[0][0] + + call_args = handler._client.worker_stub.RespondDecisionTaskFailed.call_args[0][ + 0 + ] assert isinstance(call_args.details, Payload) assert call_args.details.data == b"Test error message" - + @pytest.mark.asyncio - async def test_handle_task_failure_respond_error(self, handler, sample_decision_task): + async def test_handle_task_failure_respond_error( + self, handler, sample_decision_task + ): """Test task failure handling when 
respond fails.""" error = ValueError("Test error") - handler._client.worker_stub.RespondDecisionTaskFailed.side_effect = Exception("Respond failed") - + handler._client.worker_stub.RespondDecisionTaskFailed.side_effect = Exception( + "Respond failed" + ) + # Should not raise exception, but should log error - with patch('cadence.worker._decision_task_handler.logger') as mock_logger: + with patch("cadence.worker._decision_task_handler.logger") as mock_logger: await handler.handle_task_failure(sample_decision_task, error) # Now uses logger.error with exc_info=True instead of logger.exception mock_logger.error.assert_called() - + @pytest.mark.asyncio - async def test_respond_decision_task_completed_success(self, handler, sample_decision_task): + async def test_respond_decision_task_completed_success( + self, handler, sample_decision_task + ): """Test successful decision task completion response.""" decision_result = Mock(spec=DecisionResult) decision_result.decisions = [Decision(), Decision()] - - await handler._respond_decision_task_completed(sample_decision_task, decision_result) - + + await handler._respond_decision_task_completed( + sample_decision_task, decision_result + ) + # Verify the request was created correctly - call_args = handler._client.worker_stub.RespondDecisionTaskCompleted.call_args[0][0] + call_args = handler._client.worker_stub.RespondDecisionTaskCompleted.call_args[ + 0 + ][0] assert isinstance(call_args, RespondDecisionTaskCompletedRequest) assert call_args.task_token == sample_decision_task.task_token assert call_args.identity == handler._identity assert call_args.return_new_decision_task assert len(call_args.decisions) == 2 - + @pytest.mark.asyncio - async def test_respond_decision_task_completed_no_query_results(self, handler, sample_decision_task): + async def test_respond_decision_task_completed_no_query_results( + self, handler, sample_decision_task + ): """Test decision task completion response without query results.""" decision_result = Mock(spec=DecisionResult) decision_result.decisions = [] - - await handler._respond_decision_task_completed(sample_decision_task, decision_result) - - call_args = handler._client.worker_stub.RespondDecisionTaskCompleted.call_args[0][0] + + await handler._respond_decision_task_completed( + sample_decision_task, decision_result + ) + + call_args = handler._client.worker_stub.RespondDecisionTaskCompleted.call_args[ + 0 + ][0] assert call_args.return_new_decision_task assert len(call_args.decisions) == 0 - + @pytest.mark.asyncio - async def test_respond_decision_task_completed_error(self, handler, sample_decision_task): + async def test_respond_decision_task_completed_error( + self, handler, sample_decision_task + ): """Test decision task completion response error handling.""" decision_result = Mock(spec=DecisionResult) decision_result.decisions = [] - - handler._client.worker_stub.RespondDecisionTaskCompleted.side_effect = Exception("Respond failed") - + + handler._client.worker_stub.RespondDecisionTaskCompleted.side_effect = ( + Exception("Respond failed") + ) + with pytest.raises(Exception, match="Respond failed"): - await handler._respond_decision_task_completed(sample_decision_task, decision_result) - - + await handler._respond_decision_task_completed( + sample_decision_task, decision_result + ) + @pytest.mark.asyncio - async def test_workflow_engine_creation_with_workflow_info(self, handler, sample_decision_task, mock_registry): + async def test_workflow_engine_creation_with_workflow_info( + self, handler, sample_decision_task, 
mock_registry + ): """Test that WorkflowEngine is created with correct WorkflowInfo.""" + # Create actual workflow definition class MockWorkflow: @workflow.run @@ -359,23 +424,28 @@ async def run(self): mock_decision_result.decisions = [] mock_engine.process_decision = AsyncMock(return_value=mock_decision_result) - with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine) as mock_workflow_engine_class: - with patch('cadence.worker._decision_task_handler.WorkflowInfo') as mock_workflow_info_class: + with patch( + "cadence.worker._decision_task_handler.WorkflowEngine", + return_value=mock_engine, + ) as mock_workflow_engine_class: + with patch( + "cadence.worker._decision_task_handler.WorkflowInfo" + ) as mock_workflow_info_class: await handler._handle_task_implementation(sample_decision_task) # Verify WorkflowInfo was created with correct parameters (called once for engine) assert mock_workflow_info_class.call_count == 1 for call in mock_workflow_info_class.call_args_list: assert call[1] == { - 'workflow_type': "TestWorkflow", - 'workflow_domain': "test_domain", - 'workflow_id': "test_workflow_id", - 'workflow_run_id': "test_run_id" + "workflow_type": "TestWorkflow", + "workflow_domain": "test_domain", + "workflow_id": "test_workflow_id", + "workflow_run_id": "test_run_id", } # Verify WorkflowEngine was created with correct parameters mock_workflow_engine_class.assert_called_once() call_args = mock_workflow_engine_class.call_args - assert call_args[1]['info'] is not None - assert call_args[1]['client'] == handler._client - assert call_args[1]['workflow_definition'] == workflow_definition + assert call_args[1]["info"] is not None + assert call_args[1]["client"] == handler._client + assert call_args[1]["workflow_definition"] == workflow_definition diff --git a/tests/cadence/worker/test_decision_task_handler_integration.py b/tests/cadence/worker/test_decision_task_handler_integration.py index fc65f0e..614d121 100644 --- a/tests/cadence/worker/test_decision_task_handler_integration.py +++ b/tests/cadence/worker/test_decision_task_handler_integration.py @@ -5,11 +5,13 @@ import pytest from unittest.mock import Mock, AsyncMock, patch -from cadence.api.v1.service_worker_pb2 import ( - PollForDecisionTaskResponse -) +from cadence.api.v1.service_worker_pb2 import PollForDecisionTaskResponse from cadence.api.v1.common_pb2 import Payload, WorkflowExecution, WorkflowType -from cadence.api.v1.history_pb2 import History, HistoryEvent, WorkflowExecutionStartedEventAttributes +from cadence.api.v1.history_pb2 import ( + History, + HistoryEvent, + WorkflowExecutionStartedEventAttributes, +) from cadence.api.v1.decision_pb2 import Decision from cadence.worker._decision_task_handler import DecisionTaskHandler from cadence.worker._registry import Registry @@ -53,137 +55,170 @@ def decision_task_handler(self, mock_client, registry): client=mock_client, task_list="test-task-list", registry=registry, - identity="test-worker" + identity="test-worker", ) - def create_mock_decision_task(self, workflow_id="test-workflow", run_id="test-run", workflow_type="test_workflow"): + def create_mock_decision_task( + self, + workflow_id="test-workflow", + run_id="test-run", + workflow_type="test_workflow", + ): """Create a mock decision task with history.""" # Create workflow execution workflow_execution = WorkflowExecution() workflow_execution.workflow_id = workflow_id workflow_execution.run_id = run_id - + # Create workflow type workflow_type_obj = WorkflowType() workflow_type_obj.name = 
workflow_type - + # Create workflow execution started event started_event = WorkflowExecutionStartedEventAttributes() input_payload = Payload(data=b'"test-input"') started_event.input.CopyFrom(input_payload) - + history_event = HistoryEvent() - history_event.workflow_execution_started_event_attributes.CopyFrom(started_event) - + history_event.workflow_execution_started_event_attributes.CopyFrom( + started_event + ) + # Create history history = History() history.events.append(history_event) - + # Create decision task decision_task = PollForDecisionTaskResponse() decision_task.task_token = b"test-task-token" decision_task.workflow_execution.CopyFrom(workflow_execution) decision_task.workflow_type.CopyFrom(workflow_type_obj) decision_task.history.CopyFrom(history) - + return decision_task @pytest.mark.asyncio - async def test_handle_decision_task_success(self, decision_task_handler, mock_client): + async def test_handle_decision_task_success( + self, decision_task_handler, mock_client + ): """Test successful decision task handling.""" # Create a mock decision task decision_task = self.create_mock_decision_task() - + # Mock the workflow engine to return some decisions # Mock the workflow engine creation and execution mock_engine = Mock() # Create a proper Decision object decision = Decision() - mock_engine.process_decision = AsyncMock(return_value=Mock( - decisions=[decision], # Proper Decision object - )) - - with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine): + mock_engine.process_decision = AsyncMock( + return_value=Mock( + decisions=[decision], # Proper Decision object + ) + ) + + with patch( + "cadence.worker._decision_task_handler.WorkflowEngine", + return_value=mock_engine, + ): # Handle the decision task await decision_task_handler._handle_task_implementation(decision_task) - + # Verify the workflow engine was called mock_engine.process_decision.assert_called_once_with(decision_task) - + # Verify the response was sent mock_client.worker_stub.RespondDecisionTaskCompleted.assert_called_once() @pytest.mark.asyncio - async def test_handle_decision_task_workflow_not_found(self, decision_task_handler, mock_client): + async def test_handle_decision_task_workflow_not_found( + self, decision_task_handler, mock_client + ): """Test decision task handling when workflow is not found in registry.""" # Create a decision task with unknown workflow type decision_task = self.create_mock_decision_task(workflow_type="unknown_workflow") - + # Handle the decision task await decision_task_handler.handle_task(decision_task) - + # Verify failure response was sent mock_client.worker_stub.RespondDecisionTaskFailed.assert_called_once() - + # Verify the failure request has the correct cause call_args = mock_client.worker_stub.RespondDecisionTaskFailed.call_args[0][0] - assert call_args.cause == 14 # DECISION_TASK_FAILED_CAUSE_WORKFLOW_WORKER_UNHANDLED_FAILURE + assert ( + call_args.cause == 14 + ) # DECISION_TASK_FAILED_CAUSE_WORKFLOW_WORKER_UNHANDLED_FAILURE @pytest.mark.asyncio - async def test_handle_decision_task_missing_workflow_execution(self, decision_task_handler, mock_client): + async def test_handle_decision_task_missing_workflow_execution( + self, decision_task_handler, mock_client + ): """Test decision task handling when workflow execution is missing.""" # Create a decision task without workflow execution decision_task = PollForDecisionTaskResponse() decision_task.task_token = b"test-task-token" # No workflow_execution set - + # Handle the decision task await 
decision_task_handler.handle_task(decision_task) - + # Verify failure response was sent mock_client.worker_stub.RespondDecisionTaskFailed.assert_called_once() - + # Verify the failure request has the correct cause call_args = mock_client.worker_stub.RespondDecisionTaskFailed.call_args[0][0] - assert call_args.cause == 14 # DECISION_TASK_FAILED_CAUSE_WORKFLOW_WORKER_UNHANDLED_FAILURE + assert ( + call_args.cause == 14 + ) # DECISION_TASK_FAILED_CAUSE_WORKFLOW_WORKER_UNHANDLED_FAILURE @pytest.mark.asyncio - async def test_workflow_engine_creation_each_task(self, decision_task_handler, mock_client): + async def test_workflow_engine_creation_each_task( + self, decision_task_handler, mock_client + ): """Test that workflow engines are created for each task.""" decision_task = self.create_mock_decision_task() - - with patch('cadence.worker._decision_task_handler.WorkflowEngine') as mock_engine_class: + + with patch( + "cadence.worker._decision_task_handler.WorkflowEngine" + ) as mock_engine_class: mock_engine = Mock() - mock_engine.process_decision = AsyncMock(return_value=Mock( - decisions=[], - )) + mock_engine.process_decision = AsyncMock( + return_value=Mock( + decisions=[], + ) + ) mock_engine_class.return_value = mock_engine - + # Handle the same decision task twice await decision_task_handler._handle_task_implementation(decision_task) await decision_task_handler._handle_task_implementation(decision_task) - + # Verify engine was created twice (once for each task) assert mock_engine_class.call_count == 2 - + # Verify engine was called twice assert mock_engine.process_decision.call_count == 2 - @pytest.mark.asyncio - async def test_decision_task_failure_handling(self, decision_task_handler, mock_client): + async def test_decision_task_failure_handling( + self, decision_task_handler, mock_client + ): """Test decision task failure handling.""" decision_task = self.create_mock_decision_task() - + # Mock the workflow engine to raise an exception - with patch('cadence.worker._decision_task_handler.WorkflowEngine') as mock_engine_class: + with patch( + "cadence.worker._decision_task_handler.WorkflowEngine" + ) as mock_engine_class: mock_engine = Mock() - mock_engine.process_decision = AsyncMock(side_effect=Exception("Test error")) + mock_engine.process_decision = AsyncMock( + side_effect=Exception("Test error") + ) mock_engine_class.return_value = mock_engine - + # Handle the decision task - this should catch the exception await decision_task_handler.handle_task(decision_task) - + # Verify failure response was sent mock_client.worker_stub.RespondDecisionTaskFailed.assert_called_once() @@ -193,20 +228,24 @@ def test_decision_task_handler_initialization(self, decision_task_handler): assert decision_task_handler._identity == "test-worker" @pytest.mark.asyncio - async def test_respond_decision_task_completed(self, decision_task_handler, mock_client): + async def test_respond_decision_task_completed( + self, decision_task_handler, mock_client + ): """Test decision task completion response.""" decision_task = self.create_mock_decision_task() - + # Create mock decision result decision_result = Mock() decision_result.decisions = [Decision()] # Proper Decision object - + # Call the response method - await decision_task_handler._respond_decision_task_completed(decision_task, decision_result) - + await decision_task_handler._respond_decision_task_completed( + decision_task, decision_result + ) + # Verify the response was sent mock_client.worker_stub.RespondDecisionTaskCompleted.assert_called_once() - + # 
Verify the request parameters call_args = mock_client.worker_stub.RespondDecisionTaskCompleted.call_args[0][0] assert call_args.task_token == b"test-task-token" @@ -214,17 +253,19 @@ async def test_respond_decision_task_completed(self, decision_task_handler, mock assert len(call_args.decisions) == 1 @pytest.mark.asyncio - async def test_respond_decision_task_failed(self, decision_task_handler, mock_client): + async def test_respond_decision_task_failed( + self, decision_task_handler, mock_client + ): """Test decision task failure response.""" decision_task = self.create_mock_decision_task() error = ValueError("Test error") - + # Call the failure method await decision_task_handler.handle_task_failure(decision_task, error) - + # Verify the failure response was sent mock_client.worker_stub.RespondDecisionTaskFailed.assert_called_once() - + # Verify the request parameters call_args = mock_client.worker_stub.RespondDecisionTaskFailed.call_args[0][0] assert call_args.task_token == b"test-task-token" diff --git a/tests/cadence/worker/test_decision_worker_integration.py b/tests/cadence/worker/test_decision_worker_integration.py index 18e970e..91368b4 100644 --- a/tests/cadence/worker/test_decision_worker_integration.py +++ b/tests/cadence/worker/test_decision_worker_integration.py @@ -8,7 +8,11 @@ from unittest.mock import Mock, AsyncMock, patch from cadence.api.v1.service_worker_pb2 import PollForDecisionTaskResponse from cadence.api.v1.common_pb2 import Payload, WorkflowExecution, WorkflowType -from cadence.api.v1.history_pb2 import History, HistoryEvent, WorkflowExecutionStartedEventAttributes +from cadence.api.v1.history_pb2 import ( + History, + HistoryEvent, + WorkflowExecutionStartedEventAttributes, +) from cadence.worker._decision import DecisionWorker from cadence.worker._registry import Registry from cadence import workflow @@ -51,45 +55,52 @@ def decision_worker(self, mock_client, registry): options = { "identity": "test-worker", "max_concurrent_decision_task_execution_size": 1, - "decision_task_pollers": 1 + "decision_task_pollers": 1, } return DecisionWorker( client=mock_client, task_list="test-task-list", registry=registry, - options=options + options=options, ) - def create_mock_decision_task(self, workflow_id="test-workflow", run_id="test-run", workflow_type="test_workflow"): + def create_mock_decision_task( + self, + workflow_id="test-workflow", + run_id="test-run", + workflow_type="test_workflow", + ): """Create a mock decision task with history.""" # Create workflow execution workflow_execution = WorkflowExecution() workflow_execution.workflow_id = workflow_id workflow_execution.run_id = run_id - + # Create workflow type workflow_type_obj = WorkflowType() workflow_type_obj.name = workflow_type - + # Create workflow execution started event started_event = WorkflowExecutionStartedEventAttributes() input_payload = Payload(data=b'"test-input"') started_event.input.CopyFrom(input_payload) - + history_event = HistoryEvent() - history_event.workflow_execution_started_event_attributes.CopyFrom(started_event) - + history_event.workflow_execution_started_event_attributes.CopyFrom( + started_event + ) + # Create history history = History() history.events.append(history_event) - + # Create decision task decision_task = PollForDecisionTaskResponse() decision_task.task_token = b"test-task-token" decision_task.workflow_execution.CopyFrom(workflow_execution) decision_task.workflow_type.CopyFrom(workflow_type_obj) decision_task.history.CopyFrom(history) - + return decision_task 
@pytest.mark.asyncio @@ -97,21 +108,21 @@ async def test_decision_worker_poll_and_execute(self, decision_worker, mock_clie """Test decision worker polling and executing tasks.""" # Create a mock decision task decision_task = self.create_mock_decision_task() - + # Mock the poll to return the decision task mock_client.worker_stub.PollForDecisionTask.return_value = decision_task - + # Mock the decision handler - with patch.object(decision_worker, '_decision_handler') as mock_handler: + with patch.object(decision_worker, "_decision_handler") as mock_handler: mock_handler.handle_task = AsyncMock() - + # Run the poll and execute await decision_worker._poll() await decision_worker._execute(decision_task) - + # Verify the poll was called mock_client.worker_stub.PollForDecisionTask.assert_called_once() - + # Verify the handler was called mock_handler.handle_task.assert_called_once_with(decision_task) @@ -120,42 +131,46 @@ async def test_decision_worker_poll_no_task(self, decision_worker, mock_client): """Test decision worker polling when no task is available.""" # Mock the poll to return None (no task) mock_client.worker_stub.PollForDecisionTask.return_value = None - + # Run the poll result = await decision_worker._poll() - + # Verify no task was returned assert result is None @pytest.mark.asyncio - async def test_decision_worker_poll_with_task_token(self, decision_worker, mock_client): + async def test_decision_worker_poll_with_task_token( + self, decision_worker, mock_client + ): """Test decision worker polling when task has token.""" # Create a decision task with token decision_task = self.create_mock_decision_task() decision_task.task_token = b"valid-token" - + # Mock the poll to return the decision task mock_client.worker_stub.PollForDecisionTask.return_value = decision_task - + # Run the poll result = await decision_worker._poll() - + # Verify the task was returned assert result == decision_task @pytest.mark.asyncio - async def test_decision_worker_poll_without_task_token(self, decision_worker, mock_client): + async def test_decision_worker_poll_without_task_token( + self, decision_worker, mock_client + ): """Test decision worker polling when task has no token.""" # Create a decision task without token decision_task = self.create_mock_decision_task() decision_task.task_token = b"" # Empty token - + # Mock the poll to return the decision task mock_client.worker_stub.PollForDecisionTask.return_value = decision_task - + # Run the poll result = await decision_worker._poll() - + # Verify no task was returned assert result is None @@ -163,34 +178,38 @@ async def test_decision_worker_poll_without_task_token(self, decision_worker, mo async def test_decision_worker_execute_success(self, decision_worker, mock_client): """Test successful decision task execution.""" decision_task = self.create_mock_decision_task() - + # Mock the decision handler - with patch.object(decision_worker, '_decision_handler') as mock_handler: + with patch.object(decision_worker, "_decision_handler") as mock_handler: mock_handler.handle_task = AsyncMock() - + # Execute the task await decision_worker._execute(decision_task) - + # Verify the handler was called mock_handler.handle_task.assert_called_once_with(decision_task) @pytest.mark.asyncio - async def test_decision_worker_execute_handler_error(self, decision_worker, mock_client): + async def test_decision_worker_execute_handler_error( + self, decision_worker, mock_client + ): """Test decision task execution when handler raises an error.""" decision_task = 
self.create_mock_decision_task() - + # Mock the decision handler to raise an error - with patch.object(decision_worker, '_decision_handler') as mock_handler: + with patch.object(decision_worker, "_decision_handler") as mock_handler: mock_handler.handle_task = AsyncMock(side_effect=Exception("Handler error")) - + # Execute the task - should raise the exception with pytest.raises(Exception, match="Handler error"): await decision_worker._execute(decision_task) - + # Verify the handler was called mock_handler.handle_task.assert_called_once_with(decision_task) - def test_decision_worker_initialization(self, decision_worker, mock_client, registry): + def test_decision_worker_initialization( + self, decision_worker, mock_client, registry + ): """Test DecisionWorker initialization.""" assert decision_worker._client == mock_client assert decision_worker._task_list == "test-task-list" @@ -203,10 +222,12 @@ def test_decision_worker_initialization(self, decision_worker, mock_client, regi async def test_decision_worker_run(self, decision_worker, mock_client): """Test DecisionWorker run method.""" # Mock the poller to complete immediately - with patch.object(decision_worker._poller, 'run', new_callable=AsyncMock) as mock_poller_run: + with patch.object( + decision_worker._poller, "run", new_callable=AsyncMock + ) as mock_poller_run: # Run the worker await decision_worker.run() - + # Verify the poller was run mock_poller_run.assert_called_once() @@ -215,47 +236,50 @@ async def test_decision_worker_integration_flow(self, decision_worker, mock_clie """Test the complete integration flow from poll to execute.""" # Create a mock decision task decision_task = self.create_mock_decision_task() - + # Mock the poll to return the decision task mock_client.worker_stub.PollForDecisionTask.return_value = decision_task - + # Mock the decision handler - with patch.object(decision_worker, '_decision_handler') as mock_handler: + with patch.object(decision_worker, "_decision_handler") as mock_handler: mock_handler.handle_task = AsyncMock() - + # Test the complete flow # 1. Poll for task polled_task = await decision_worker._poll() assert polled_task == decision_task - + # 2. Execute the task await decision_worker._execute(polled_task) - + # 3. 
Verify the handler was called mock_handler.handle_task.assert_called_once_with(decision_task) @pytest.mark.asyncio - async def test_decision_worker_with_different_workflow_types(self, decision_worker, mock_client, registry): + async def test_decision_worker_with_different_workflow_types( + self, decision_worker, mock_client, registry + ): """Test decision worker with different workflow types.""" + # Add another workflow to the registry @registry.workflow class AnotherWorkflow: @workflow.run async def run(self, input_data): return f"another-processed: {input_data}" - + # Create decision tasks for different workflow types task1 = self.create_mock_decision_task(workflow_type="test_workflow") task2 = self.create_mock_decision_task(workflow_type="another_workflow") - + # Mock the decision handler - with patch.object(decision_worker, '_decision_handler') as mock_handler: + with patch.object(decision_worker, "_decision_handler") as mock_handler: mock_handler.handle_task = AsyncMock() - + # Execute both tasks await decision_worker._execute(task1) await decision_worker._execute(task2) - + # Verify both tasks were handled assert mock_handler.handle_task.call_count == 2 @@ -263,8 +287,10 @@ async def run(self, input_data): async def test_decision_worker_poll_timeout(self, decision_worker, mock_client): """Test decision worker polling with timeout.""" # Mock the poll to raise a timeout exception - mock_client.worker_stub.PollForDecisionTask.side_effect = asyncio.TimeoutError("Poll timeout") - + mock_client.worker_stub.PollForDecisionTask.side_effect = asyncio.TimeoutError( + "Poll timeout" + ) + # Run the poll - should handle timeout gracefully with pytest.raises(asyncio.TimeoutError): await decision_worker._poll() @@ -274,16 +300,16 @@ def test_decision_worker_options_handling(self, mock_client, registry): options = { "identity": "custom-worker", "max_concurrent_decision_task_execution_size": 5, - "decision_task_pollers": 3 + "decision_task_pollers": 3, } - + worker = DecisionWorker( client=mock_client, task_list="custom-task-list", registry=registry, - options=options + options=options, ) - + # Verify options were applied assert worker._identity == "custom-worker" assert worker._task_list == "custom-task-list" diff --git a/tests/cadence/worker/test_poller.py b/tests/cadence/worker/test_poller.py index 2cb3e00..ee891a4 100644 --- a/tests/cadence/worker/test_poller.py +++ b/tests/cadence/worker/test_poller.py @@ -21,6 +21,7 @@ async def test_poller(): assert incoming.empty() is True assert outgoing.empty() is True + @pytest.mark.asyncio async def test_poller_empty_task(): permits = asyncio.Semaphore(1) @@ -36,6 +37,7 @@ async def test_poller_empty_task(): assert result == "foo" task.cancel() + @pytest.mark.asyncio async def test_poller_num_tasks(): permits = asyncio.Semaphore(10) @@ -65,6 +67,7 @@ async def poll_func(): task.cancel() + @pytest.mark.asyncio async def test_poller_concurrency(): permits = asyncio.Semaphore(5) @@ -106,6 +109,7 @@ async def test_poller_poll_error(): done = asyncio.Event() call_count = 0 + async def poll_func(): nonlocal call_count call_count += 1 @@ -127,12 +131,14 @@ async def poll_func(): task.cancel() done.set() + @pytest.mark.asyncio async def test_poller_execute_error(): permits = asyncio.Semaphore(1) outgoing = asyncio.Queue() call_count = 0 + async def execute(item: str): nonlocal call_count call_count += 1 @@ -150,5 +156,3 @@ async def execute(item: str): assert result == "second" task.cancel() - - diff --git a/tests/cadence/worker/test_registry.py 
b/tests/cadence/worker/test_registry.py index bf6721e..0a61d8b 100644 --- a/tests/cadence/worker/test_registry.py +++ b/tests/cadence/worker/test_registry.py @@ -14,7 +14,7 @@ class TestRegistry: """Test registry functionality.""" - + def test_basic_registry_creation(self): """Test basic registry creation.""" reg = Registry() @@ -22,7 +22,7 @@ def test_basic_registry_creation(self): reg.get_workflow("nonexistent") with pytest.raises(KeyError): reg.get_activity("nonexistent") - + def test_basic_workflow_registration_and_retrieval(self): """Test basic registration and retrieval for class-based workflows.""" reg = Registry() @@ -50,7 +50,7 @@ def test_func(): func = reg.get_activity(test_func.name) assert func() == "test" - + def test_direct_call_behavior(self): reg = Registry() @@ -60,9 +60,9 @@ def test_func(): reg.register_activity(test_func) func = reg.get_activity("test_func") - + assert func() == "direct_call" - + def test_workflow_not_found_error(self): """Test KeyError is raised when workflow not found.""" reg = Registry() @@ -86,6 +86,7 @@ async def run(self): return "test" with pytest.raises(KeyError): + @reg.workflow(name="duplicate_test") class TestWorkflow2: @workflow.run @@ -101,6 +102,7 @@ def test_func(): return "test" with pytest.raises(KeyError): + @reg.activity(name="test_func") def test_func(): return "duplicate" @@ -119,9 +121,15 @@ def test_register_activities_interface(self): reg.register_activities(impl) - assert reg.get_activity(common_activities.ActivityInterface.do_something.name) is not None + assert ( + reg.get_activity(common_activities.ActivityInterface.do_something.name) + is not None + ) assert reg.get_activity("ActivityInterface.do_something") is not None - assert reg.get_activity(common_activities.ActivityInterface.do_something.name)() == "result" + assert ( + reg.get_activity(common_activities.ActivityInterface.do_something.name)() + == "result" + ) def test_register_activities_invalid_impl(self): impl = common_activities.InvalidImpl() @@ -130,7 +138,6 @@ def test_register_activities_invalid_impl(self): with pytest.raises(ValueError): reg.register_activities(impl) - def test_add(self): registry = Registry() registry.register_activity(common_activities.simple_fn) @@ -173,6 +180,7 @@ def test_class_workflow_validation_errors(self): # Test missing run method with pytest.raises(ValueError, match="No @workflow.run method found"): + @reg.workflow class MissingRunWorkflow: def some_method(self): @@ -180,6 +188,7 @@ def some_method(self): # Test duplicate run methods with pytest.raises(ValueError, match="Multiple @workflow.run methods found"): + @reg.workflow class DuplicateRunWorkflow: @workflow.run diff --git a/tests/cadence/worker/test_task_handler_integration.py b/tests/cadence/worker/test_task_handler_integration.py index daa36bb..5fec846 100644 --- a/tests/cadence/worker/test_task_handler_integration.py +++ b/tests/cadence/worker/test_task_handler_integration.py @@ -18,7 +18,7 @@ class TestTaskHandlerIntegration: """Integration tests for task handlers.""" - + @pytest.fixture def mock_client(self): """Create a mock client.""" @@ -28,13 +28,13 @@ def mock_client(self): client.worker_stub.RespondDecisionTaskFailed = AsyncMock() type(client).domain = PropertyMock(return_value="test_domain") return client - + @pytest.fixture def mock_registry(self): """Create a mock registry.""" registry = Mock(spec=Registry) return registry - + @pytest.fixture def handler(self, mock_client, mock_registry): """Create a DecisionTaskHandler instance.""" @@ -42,9 +42,9 @@ def 
handler(self, mock_client, mock_registry): client=mock_client, task_list="test_task_list", registry=mock_registry, - identity="test_identity" + identity="test_identity", ) - + @pytest.fixture def sample_decision_task(self): """Create a sample decision task.""" @@ -59,13 +59,14 @@ def sample_decision_task(self): task.started_event_id = 1 task.attempt = 1 return task - + @pytest.mark.asyncio - async def test_full_task_handling_flow_success(self, handler, sample_decision_task, mock_registry): + async def test_full_task_handling_flow_success( + self, handler, sample_decision_task, mock_registry + ): """Test the complete task handling flow from base handler through decision handler.""" # Create actual workflow definition - class MockWorkflow: @workflow.run async def run(self): @@ -74,26 +75,32 @@ async def run(self): workflow_opts = WorkflowDefinitionOptions(name="test_workflow") workflow_definition = WorkflowDefinition.wrap(MockWorkflow, workflow_opts) mock_registry.get_workflow.return_value = workflow_definition - + # Mock workflow engine mock_engine = Mock(spec=WorkflowEngine) mock_engine._is_workflow_complete = False # Add missing attribute mock_decision_result = Mock(spec=DecisionResult) mock_decision_result.decisions = [] mock_engine.process_decision = AsyncMock(return_value=mock_decision_result) - - with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine): + + with patch( + "cadence.worker._decision_task_handler.WorkflowEngine", + return_value=mock_engine, + ): # Use the base handler's handle_task method await handler.handle_task(sample_decision_task) - + # Verify the complete flow mock_registry.get_workflow.assert_called_once_with("TestWorkflow") mock_engine.process_decision.assert_called_once_with(sample_decision_task) handler._client.worker_stub.RespondDecisionTaskCompleted.assert_called_once() - + @pytest.mark.asyncio - async def test_full_task_handling_flow_with_error(self, handler, sample_decision_task, mock_registry): + async def test_full_task_handling_flow_with_error( + self, handler, sample_decision_task, mock_registry + ): """Test the complete task handling flow when an error occurs.""" + # Create actual workflow definition class MockWorkflow: @workflow.run @@ -103,25 +110,35 @@ async def run(self): workflow_opts = WorkflowDefinitionOptions(name="test_workflow") workflow_definition = WorkflowDefinition.wrap(MockWorkflow, workflow_opts) mock_registry.get_workflow.return_value = workflow_definition - + # Mock workflow engine to raise an error mock_engine = Mock(spec=WorkflowEngine) mock_engine._is_workflow_complete = False # Add missing attribute - mock_engine.process_decision = AsyncMock(side_effect=RuntimeError("Workflow processing failed")) - - with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine): + mock_engine.process_decision = AsyncMock( + side_effect=RuntimeError("Workflow processing failed") + ) + + with patch( + "cadence.worker._decision_task_handler.WorkflowEngine", + return_value=mock_engine, + ): # Use the base handler's handle_task method await handler.handle_task(sample_decision_task) - + # Verify error handling handler._client.worker_stub.RespondDecisionTaskFailed.assert_called_once() - call_args = handler._client.worker_stub.RespondDecisionTaskFailed.call_args[0][0] + call_args = handler._client.worker_stub.RespondDecisionTaskFailed.call_args[0][ + 0 + ] assert call_args.task_token == sample_decision_task.task_token assert call_args.identity == handler._identity - + @pytest.mark.asyncio - 
async def test_context_activation_integration(self, handler, sample_decision_task, mock_registry): + async def test_context_activation_integration( + self, handler, sample_decision_task, mock_registry + ): """Test that context activation works correctly in the integration.""" + # Create actual workflow definition class MockWorkflow: @workflow.run @@ -131,35 +148,43 @@ async def run(self): workflow_opts = WorkflowDefinitionOptions(name="test_workflow") workflow_definition = WorkflowDefinition.wrap(MockWorkflow, workflow_opts) mock_registry.get_workflow.return_value = workflow_definition - + # Mock workflow engine mock_engine = Mock(spec=WorkflowEngine) mock_engine._is_workflow_complete = False # Add missing attribute mock_decision_result = Mock(spec=DecisionResult) mock_decision_result.decisions = [] mock_engine.process_decision = AsyncMock(return_value=mock_decision_result) - + # Track if context is activated context_activated = False - + def track_context_activation(): nonlocal context_activated context_activated = True - - with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine): - with patch('cadence._internal.workflow.workflow_engine.Context') as mock_context_class: + + with patch( + "cadence.worker._decision_task_handler.WorkflowEngine", + return_value=mock_engine, + ): + with patch( + "cadence._internal.workflow.workflow_engine.Context" + ) as mock_context_class: mock_context = Mock() - mock_context._activate = Mock(return_value=contextmanager(lambda: track_context_activation())()) + mock_context._activate = Mock( + return_value=contextmanager(lambda: track_context_activation())() + ) mock_context_class.return_value = mock_context - + await handler.handle_task(sample_decision_task) - + # Verify context was activated assert context_activated - + @pytest.mark.asyncio async def test_multiple_workflow_executions(self, handler, mock_registry): """Test handling multiple workflow executions creates new engines for each.""" + # Create actual workflow definition class MockWorkflow: @workflow.run @@ -169,7 +194,7 @@ async def run(self): workflow_opts = WorkflowDefinitionOptions(name="test_workflow") workflow_definition = WorkflowDefinition.wrap(MockWorkflow, workflow_opts) mock_registry.get_workflow.return_value = workflow_definition - + # Create multiple decision tasks for different workflows task1 = Mock(spec=PollForDecisionTaskResponse) task1.task_token = b"task1_token" @@ -180,7 +205,7 @@ async def run(self): task1.workflow_type.name = "TestWorkflow" task1.started_event_id = 1 task1.attempt = 1 - + task2 = Mock(spec=PollForDecisionTaskResponse) task2.task_token = b"task2_token" task2.workflow_execution = Mock() @@ -190,30 +215,36 @@ async def run(self): task2.workflow_type.name = "TestWorkflow" task2.started_event_id = 2 task2.attempt = 1 - + # Mock workflow engine mock_engine = Mock(spec=WorkflowEngine) mock_engine._is_workflow_complete = False # Add missing attribute - + mock_decision_result = Mock(spec=DecisionResult) mock_decision_result.decisions = [] - + mock_engine.process_decision = AsyncMock(return_value=mock_decision_result) - - with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine) as mock_engine_class: + + with patch( + "cadence.worker._decision_task_handler.WorkflowEngine", + return_value=mock_engine, + ) as mock_engine_class: # Process both tasks await handler.handle_task(task1) await handler.handle_task(task2) - + # Verify engines were created for each task assert mock_engine_class.call_count == 2 - + 
# Verify both tasks were processed assert mock_engine.process_decision.call_count == 2 - + @pytest.mark.asyncio - async def test_workflow_engine_creation_integration(self, handler, sample_decision_task, mock_registry): + async def test_workflow_engine_creation_integration( + self, handler, sample_decision_task, mock_registry + ): """Test workflow engine creation integration.""" + # Create actual workflow definition class MockWorkflow: @workflow.run @@ -223,25 +254,31 @@ async def run(self): workflow_opts = WorkflowDefinitionOptions(name="test_workflow") workflow_definition = WorkflowDefinition.wrap(MockWorkflow, workflow_opts) mock_registry.get_workflow.return_value = workflow_definition - + # Mock workflow engine mock_engine = Mock(spec=WorkflowEngine) mock_engine._is_workflow_complete = False # Add missing attribute mock_decision_result = Mock(spec=DecisionResult) mock_decision_result.decisions = [] mock_engine.process_decision = AsyncMock(return_value=mock_decision_result) - - with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine) as mock_engine_class: + + with patch( + "cadence.worker._decision_task_handler.WorkflowEngine", + return_value=mock_engine, + ) as mock_engine_class: # Process task to create engine await handler.handle_task(sample_decision_task) - + # Verify engine was created and used mock_engine_class.assert_called_once() mock_engine.process_decision.assert_called_once_with(sample_decision_task) - + @pytest.mark.asyncio - async def test_error_handling_with_context_cleanup(self, handler, sample_decision_task, mock_registry): + async def test_error_handling_with_context_cleanup( + self, handler, sample_decision_task, mock_registry + ): """Test that context cleanup happens even when errors occur.""" + # Create actual workflow definition class MockWorkflow: @workflow.run @@ -251,38 +288,47 @@ async def run(self): workflow_opts = WorkflowDefinitionOptions(name="test_workflow") workflow_definition = WorkflowDefinition.wrap(MockWorkflow, workflow_opts) mock_registry.get_workflow.return_value = workflow_definition - + # Mock workflow engine to raise an error mock_engine = Mock(spec=WorkflowEngine) mock_engine._is_workflow_complete = False # Add missing attribute - mock_engine.process_decision = AsyncMock(side_effect=RuntimeError("Workflow processing failed")) - + mock_engine.process_decision = AsyncMock( + side_effect=RuntimeError("Workflow processing failed") + ) + # Track context cleanup context_cleaned_up = False - + def track_context_cleanup(): nonlocal context_cleaned_up context_cleaned_up = True - - with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine): - with patch('cadence._internal.workflow.workflow_engine.Context') as mock_context_class: + + with patch( + "cadence.worker._decision_task_handler.WorkflowEngine", + return_value=mock_engine, + ): + with patch( + "cadence._internal.workflow.workflow_engine.Context" + ) as mock_context_class: mock_context = Mock() - mock_context._activate = Mock(return_value=contextmanager(lambda: track_context_cleanup())()) + mock_context._activate = Mock( + return_value=contextmanager(lambda: track_context_cleanup())() + ) mock_context_class.return_value = mock_context - + await handler.handle_task(sample_decision_task) - + # Verify context was cleaned up even after error assert context_cleaned_up - + # Verify error was handled handler._client.worker_stub.RespondDecisionTaskFailed.assert_called_once() - + @pytest.mark.asyncio async def 
test_concurrent_task_handling(self, handler, mock_registry): """Test handling multiple tasks concurrently.""" import asyncio - + # Create actual workflow definition class MockWorkflow: @workflow.run @@ -292,7 +338,7 @@ async def run(self): workflow_opts = WorkflowDefinitionOptions(name="test_workflow") workflow_definition = WorkflowDefinition.wrap(MockWorkflow, workflow_opts) mock_registry.get_workflow.return_value = workflow_definition - + # Create multiple tasks tasks = [] for i in range(3): @@ -306,18 +352,21 @@ async def run(self): task.started_event_id = i + 1 task.attempt = 1 tasks.append(task) - + # Mock workflow engine mock_engine = Mock(spec=WorkflowEngine) mock_engine._is_workflow_complete = False # Add missing attribute mock_decision_result = Mock(spec=DecisionResult) mock_decision_result.decisions = [] mock_engine.process_decision = AsyncMock(return_value=mock_decision_result) - - with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine): + + with patch( + "cadence.worker._decision_task_handler.WorkflowEngine", + return_value=mock_engine, + ): # Process all tasks concurrently await asyncio.gather(*[handler.handle_task(task) for task in tasks]) - + # Verify all tasks were processed assert mock_engine.process_decision.call_count == 3 assert handler._client.worker_stub.RespondDecisionTaskCompleted.call_count == 3 diff --git a/tests/cadence/worker/test_worker.py b/tests/cadence/worker/test_worker.py index 951ae6d..ebc3f65 100644 --- a/tests/cadence/worker/test_worker.py +++ b/tests/cadence/worker/test_worker.py @@ -4,7 +4,10 @@ from unittest.mock import AsyncMock, Mock, PropertyMock -from cadence.api.v1.service_worker_pb2 import PollForDecisionTaskRequest, PollForActivityTaskRequest +from cadence.api.v1.service_worker_pb2 import ( + PollForDecisionTaskRequest, + PollForActivityTaskRequest, +) from cadence.api.v1.tasklist_pb2 import TaskList, TaskListKind from cadence.client import Client from cadence.worker import Worker, Registry @@ -29,7 +32,14 @@ async def poll(_, timeout=0.0): type(client).domain = PropertyMock(return_value="domain") type(client).identity = PropertyMock(return_value="identity") - worker = Worker(client, "task_list", Registry(), activity_task_pollers=1, decision_task_pollers=1, identity="identity") + worker = Worker( + client, + "task_list", + Registry(), + activity_task_pollers=1, + decision_task_pollers=1, + identity="identity", + ) task = asyncio.create_task(worker.run()) @@ -39,14 +49,24 @@ async def poll(_, timeout=0.0): with pytest.raises(asyncio.CancelledError): await task - worker_stub.PollForDecisionTask.assert_called_once_with(PollForDecisionTaskRequest( - domain="domain", - identity="identity", - task_list=TaskList(name="task_list", kind=TaskListKind.TASK_LIST_KIND_NORMAL), - ), timeout=60.0) + worker_stub.PollForDecisionTask.assert_called_once_with( + PollForDecisionTaskRequest( + domain="domain", + identity="identity", + task_list=TaskList( + name="task_list", kind=TaskListKind.TASK_LIST_KIND_NORMAL + ), + ), + timeout=60.0, + ) - worker_stub.PollForActivityTask.assert_called_once_with(PollForActivityTaskRequest( - domain="domain", - identity="identity", - task_list=TaskList(name="task_list", kind=TaskListKind.TASK_LIST_KIND_NORMAL), - ), timeout=60.0) \ No newline at end of file + worker_stub.PollForActivityTask.assert_called_once_with( + PollForActivityTaskRequest( + domain="domain", + identity="identity", + task_list=TaskList( + name="task_list", kind=TaskListKind.TASK_LIST_KIND_NORMAL + ), + ), + timeout=60.0, 
+    )
diff --git a/tests/conftest.py b/tests/conftest.py
index 5899674..4983c71 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,8 +1,10 @@
-
-
 ENABLE_INTEGRATION_TESTS = "--integration-tests"
 
+
 # Need to define the option in the root conftest.py file
 def pytest_addoption(parser):
-    parser.addoption(ENABLE_INTEGRATION_TESTS, action="store_true",
-                     help="enables running integration tests, which rely on docker and docker-compose")
\ No newline at end of file
+    parser.addoption(
+        ENABLE_INTEGRATION_TESTS,
+        action="store_true",
+        help="enables running integration tests, which rely on docker and docker-compose",
+    )
diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py
index 0833e24..a4486eb 100644
--- a/tests/integration_tests/conftest.py
+++ b/tests/integration_tests/conftest.py
@@ -12,22 +12,28 @@
 from tests.conftest import ENABLE_INTEGRATION_TESTS
 from tests.integration_tests.helper import CadenceHelper, DOMAIN_NAME
 
+
 # Run tests in this directory and lower only if integration tests are enabled
 def pytest_runtest_setup(item):
     if not item.config.getoption(ENABLE_INTEGRATION_TESTS):
         pytest.skip(f"{ENABLE_INTEGRATION_TESTS} not enabled")
 
+
 @pytest.fixture(scope="session")
 def docker_compose_file(pytestconfig):
-    return os.path.join(str(pytestconfig.rootdir), "tests", "integration_tests", "docker-compose.yml")
+    return os.path.join(
+        str(pytestconfig.rootdir), "tests", "integration_tests", "docker-compose.yml"
+    )
+
 
 @pytest.fixture(scope="session")
 def client_options(docker_ip: str, docker_services: Services) -> ClientOptions:
     return ClientOptions(
         domain=DOMAIN_NAME,
-        target=f'{docker_ip}:{docker_services.port_for("cadence", 7833)}',
+        target=f"{docker_ip}:{docker_services.port_for('cadence', 7833)}",
     )
 
+
 # We can't pass around Client objects between tests/fixtures without changing our pytest-asyncio version
 # to ensure that they use the same event loop.
# Instead, we can wait for the server to be ready, create the common domain, and then provide a helper capable @@ -40,8 +46,10 @@ async def helper(client_options: ClientOptions) -> CadenceHelper: async with asyncio.timeout(120): await client.ready() - await client.domain_stub.RegisterDomain(RegisterDomainRequest( - name=DOMAIN_NAME, - workflow_execution_retention_period=from_timedelta(timedelta(days=1)), - )) + await client.domain_stub.RegisterDomain( + RegisterDomainRequest( + name=DOMAIN_NAME, + workflow_execution_retention_period=from_timedelta(timedelta(days=1)), + ) + ) return CadenceHelper(client_options) diff --git a/tests/integration_tests/helper.py b/tests/integration_tests/helper.py index 06fb5a0..5dac03d 100644 --- a/tests/integration_tests/helper.py +++ b/tests/integration_tests/helper.py @@ -8,4 +8,4 @@ def __init__(self, options: ClientOptions): self.options = options def client(self): - return Client(**self.options) \ No newline at end of file + return Client(**self.options) diff --git a/tests/integration_tests/test_client.py b/tests/integration_tests/test_client.py index 7acba3d..3f96b9b 100644 --- a/tests/integration_tests/test_client.py +++ b/tests/integration_tests/test_client.py @@ -1,25 +1,37 @@ from datetime import timedelta import pytest -from cadence.api.v1.service_domain_pb2 import DescribeDomainRequest, DescribeDomainResponse +from cadence.api.v1.service_domain_pb2 import ( + DescribeDomainRequest, + DescribeDomainResponse, +) from cadence.api.v1.service_workflow_pb2 import DescribeWorkflowExecutionRequest from cadence.api.v1.common_pb2 import WorkflowExecution + from cadence.error import EntityNotExistsError from tests.integration_tests.helper import CadenceHelper, DOMAIN_NAME + @pytest.mark.usefixtures("helper") async def test_domain_exists(helper: CadenceHelper): async with helper.client() as client: - response: DescribeDomainResponse = await client.domain_stub.DescribeDomain(DescribeDomainRequest(name=DOMAIN_NAME)) + response: DescribeDomainResponse = await client.domain_stub.DescribeDomain( + DescribeDomainRequest(name=DOMAIN_NAME) + ) assert response.domain.name == DOMAIN_NAME + @pytest.mark.usefixtures("helper") async def test_domain_not_exists(helper: CadenceHelper): with pytest.raises(EntityNotExistsError): async with helper.client() as client: - await client.domain_stub.DescribeDomain(DescribeDomainRequest(name="unknown-domain")) + await client.domain_stub.DescribeDomain( + DescribeDomainRequest(name="unknown-domain") + ) + # Worker Stub Tests + @pytest.mark.usefixtures("helper") async def test_worker_stub_accessible(helper: CadenceHelper): """Test that worker_stub is properly initialized and accessible.""" @@ -27,14 +39,17 @@ async def test_worker_stub_accessible(helper: CadenceHelper): assert client.worker_stub is not None # Verify it's the correct type from cadence.api.v1.service_worker_pb2_grpc import WorkerAPIStub + assert isinstance(client.worker_stub, WorkerAPIStub) + # Workflow Stub Tests + @pytest.mark.usefixtures("helper") async def test_workflow_stub_start_and_describe(helper: CadenceHelper): """Comprehensive test for workflow start and describe operations. - + This integration test verifies: 1. Starting a workflow execution via workflow_stub 2. 
Describing the workflow execution @@ -51,7 +66,7 @@ async def test_workflow_stub_start_and_describe(helper: CadenceHelper): workflow_id = "test-workflow-describe-456" execution_timeout = timedelta(minutes=5) task_timeout = timedelta(seconds=10) # Default value - + # Start a workflow with specific parameters execution = await client.start_workflow( workflow_type, @@ -60,7 +75,7 @@ async def test_workflow_stub_start_and_describe(helper: CadenceHelper): task_start_to_close_timeout=task_timeout, workflow_id=workflow_id, ) - + # Describe the workflow execution describe_request = DescribeWorkflowExecutionRequest( domain=DOMAIN_NAME, @@ -69,40 +84,55 @@ async def test_workflow_stub_start_and_describe(helper: CadenceHelper): run_id=execution.run_id, ), ) - - response = await client.workflow_stub.DescribeWorkflowExecution(describe_request) - + + response = await client.workflow_stub.DescribeWorkflowExecution( + describe_request + ) + # Assert workflow execution info matches assert response is not None, "DescribeWorkflowExecution returned None" - assert response.workflow_execution_info is not None, "workflow_execution_info is None" - + assert response.workflow_execution_info is not None, ( + "workflow_execution_info is None" + ) + # Verify workflow execution identifiers wf_exec = response.workflow_execution_info.workflow_execution - assert wf_exec.workflow_id == workflow_id, \ + assert wf_exec.workflow_id == workflow_id, ( f"workflow_id mismatch: expected {workflow_id}, got {wf_exec.workflow_id}" - assert wf_exec.run_id == execution.run_id, \ + ) + assert wf_exec.run_id == execution.run_id, ( f"run_id mismatch: expected {execution.run_id}, got {wf_exec.run_id}" - + ) + # Verify workflow type - assert response.workflow_execution_info.type.name == workflow_type, \ + assert response.workflow_execution_info.type.name == workflow_type, ( f"workflow_type mismatch: expected {workflow_type}, got {response.workflow_execution_info.type.name}" - + ) + # Verify task list - assert response.workflow_execution_info.task_list == task_list_name, \ + assert response.workflow_execution_info.task_list == task_list_name, ( f"task_list mismatch: expected {task_list_name}, got {response.workflow_execution_info.task_list}" - + ) + # Verify execution configuration - assert response.execution_configuration is not None, "execution_configuration is None" - + assert response.execution_configuration is not None, ( + "execution_configuration is None" + ) + # Verify task list in configuration - assert response.execution_configuration.task_list.name == task_list_name, \ + assert response.execution_configuration.task_list.name == task_list_name, ( f"config task_list mismatch: expected {task_list_name}, got {response.execution_configuration.task_list.name}" - + ) + # Verify timeouts exec_timeout_seconds = response.execution_configuration.execution_start_to_close_timeout.ToSeconds() - assert exec_timeout_seconds == execution_timeout.total_seconds(), \ + assert exec_timeout_seconds == execution_timeout.total_seconds(), ( f"execution_start_to_close_timeout mismatch: expected {execution_timeout.total_seconds()}s, got {exec_timeout_seconds}s" - - task_timeout_seconds = response.execution_configuration.task_start_to_close_timeout.ToSeconds() - assert task_timeout_seconds == task_timeout.total_seconds(), \ + ) + + task_timeout_seconds = ( + response.execution_configuration.task_start_to_close_timeout.ToSeconds() + ) + assert task_timeout_seconds == task_timeout.total_seconds(), ( f"task_start_to_close_timeout mismatch: expected 
{task_timeout.total_seconds()}s, got {task_timeout_seconds}s"
+        )

From e377fe01e65c26ebf779dfd1fd23cd0acb63d932 Mon Sep 17 00:00:00 2001
From: Shijie Sheng
Date: Wed, 29 Oct 2025 11:05:32 -0700
Subject: [PATCH 2/2] fix git action as well

Signed-off-by: Shijie Sheng
---
 .github/workflows/ci_checks.yml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.github/workflows/ci_checks.yml b/.github/workflows/ci_checks.yml
index f3cf642..a8c06a1 100644
--- a/.github/workflows/ci_checks.yml
+++ b/.github/workflows/ci_checks.yml
@@ -34,6 +34,7 @@ jobs:
       - name: Run Ruff linter
         run: |
           uv tool run ruff check
+          uv tool run ruff format --check
 
   type-check:
     name: Type Safety Check