diff --git a/cadence/_internal/activity/_activity_executor.py b/cadence/_internal/activity/_activity_executor.py index 6c2a7c2..9b5ff7f 100644 --- a/cadence/_internal/activity/_activity_executor.py +++ b/cadence/_internal/activity/_activity_executor.py @@ -73,7 +73,7 @@ async def _report_failure( _logger.exception("Exception reporting activity failure") async def _report_success(self, task: PollForActivityTaskResponse, result: Any): - as_payload = await self._data_converter.to_data([result]) + as_payload = self._data_converter.to_data([result]) try: await self._client.worker_stub.RespondActivityTaskCompleted( diff --git a/cadence/_internal/activity/_context.py b/cadence/_internal/activity/_context.py index 22f7f85..6839070 100644 --- a/cadence/_internal/activity/_context.py +++ b/cadence/_internal/activity/_context.py @@ -19,13 +19,13 @@ def __init__( self._activity_fn = activity_fn async def execute(self, payload: Payload) -> Any: - params = await self._to_params(payload) + params = self._to_params(payload) with self._activate(): return await self._activity_fn(*params) - async def _to_params(self, payload: Payload) -> list[Any]: + def _to_params(self, payload: Payload) -> list[Any]: type_hints = [param.type_hint for param in self._activity_fn.params] - return await self._client.data_converter.from_data(payload, type_hints) + return self._client.data_converter.from_data(payload, type_hints) def client(self) -> Client: return self._client @@ -46,7 +46,7 @@ def __init__( self._executor = executor async def execute(self, payload: Payload) -> Any: - params = await self._to_params(payload) + params = self._to_params(payload) loop = asyncio.get_running_loop() return await loop.run_in_executor(self._executor, self._run, params) diff --git a/cadence/_internal/decision_state_machine.py b/cadence/_internal/decision_state_machine.py deleted file mode 100644 index a5b4db9..0000000 --- a/cadence/_internal/decision_state_machine.py +++ /dev/null @@ -1,950 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass, field -from enum import Enum -from typing import Dict, List, Optional, Callable, TypedDict, Literal - -from cadence.api.v1 import ( - decision_pb2 as decision, - history_pb2 as history, - common_pb2 as common, -) - - -class DecisionState(Enum): - """Lifecycle states for a decision-producing state machine instance.""" - - CREATED = 0 - DECISION_SENT = 1 - CANCELED_BEFORE_INITIATED = 2 - INITIATED = 3 - STARTED = 4 - CANCELED_AFTER_INITIATED = 5 - CANCELED_AFTER_STARTED = 6 - CANCELLATION_DECISION_SENT = 7 - COMPLETED_AFTER_CANCELLATION_DECISION_SENT = 8 - COMPLETED = 9 - - @classmethod - def to_string(cls, state: DecisionState) -> str: - mapping = { - DecisionState.CREATED: "Created", - DecisionState.DECISION_SENT: "DecisionSent", - DecisionState.CANCELED_BEFORE_INITIATED: "CanceledBeforeInitiated", - DecisionState.INITIATED: "Initiated", - DecisionState.STARTED: "Started", - DecisionState.CANCELED_AFTER_INITIATED: "CanceledAfterInitiated", - DecisionState.CANCELED_AFTER_STARTED: "CanceledAfterStarted", - DecisionState.CANCELLATION_DECISION_SENT: "CancellationDecisionSent", - DecisionState.COMPLETED_AFTER_CANCELLATION_DECISION_SENT: "CompletedAfterCancellationDecisionSent", - DecisionState.COMPLETED: "Completed", - } - return mapping.get(state, "Unknown") - - -class DecisionType(Enum): - """Types of decisions that can be made by state machines.""" - - ACTIVITY = 0 - CHILD_WORKFLOW = 1 - CANCELLATION = 2 - MARKER = 3 - TIMER = 4 - SIGNAL = 5 - UPSERT_SEARCH_ATTRIBUTES = 6 - - @classmethod - def to_string(cls, dt: DecisionType) -> str: - mapping = { - DecisionType.ACTIVITY: "Activity", - DecisionType.CHILD_WORKFLOW: "ChildWorkflow", - DecisionType.CANCELLATION: "Cancellation", - DecisionType.MARKER: "Marker", - DecisionType.TIMER: "Timer", - DecisionType.SIGNAL: "Signal", - DecisionType.UPSERT_SEARCH_ATTRIBUTES: "UpsertSearchAttributes", - } - return mapping.get(dt, "Unknown") - - -@dataclass(frozen=True) -class DecisionId: - decision_type: DecisionType - id: str - - def __str__(self) -> str: - return ( - f"DecisionType: {DecisionType.to_string(self.decision_type)}, ID: {self.id}" - ) - - -@dataclass -class StateTransition: - """Represents a state transition with associated actions.""" - - next_state: Optional[DecisionState] - action: Optional[ - Callable[["BaseDecisionStateMachine", history.HistoryEvent], None] - ] = None - condition: Optional[ - Callable[["BaseDecisionStateMachine", history.HistoryEvent], bool] - ] = None - - -class TransitionInfo(TypedDict): - type: Literal[ - "initiated", - "started", - "completion", - "canceled", - "cancel_initiated", - "cancel_failed", - "initiation_failed", - ] - decision_type: DecisionType - transition: StateTransition - - -decision_state_transition_map: Dict[str, TransitionInfo] = { - "activity_task_scheduled_event_attributes": { - "type": "initiated", - "decision_type": DecisionType.ACTIVITY, - "transition": StateTransition(next_state=DecisionState.INITIATED), - }, - "activity_task_started_event_attributes": { - "type": "started", - "decision_type": DecisionType.ACTIVITY, - "transition": StateTransition(next_state=DecisionState.STARTED), - }, - "activity_task_completed_event_attributes": { - "type": "completion", - "decision_type": DecisionType.ACTIVITY, - "transition": StateTransition( - next_state=DecisionState.COMPLETED, - action=lambda self, event: setattr(self, "status", DecisionState.COMPLETED), - ), - }, - "activity_task_failed_event_attributes": { - "type": "completion", - "decision_type": DecisionType.ACTIVITY, - "transition": StateTransition( - next_state=DecisionState.COMPLETED, - action=lambda self, event: setattr(self, "status", DecisionState.COMPLETED), - ), - }, - "activity_task_timed_out_event_attributes": { - "type": "completion", - "decision_type": DecisionType.ACTIVITY, - "transition": StateTransition( - next_state=DecisionState.COMPLETED, - action=lambda self, event: setattr(self, "status", DecisionState.COMPLETED), - ), - }, - "activity_task_cancel_requested_event_attributes": { - "type": "cancel_initiated", - "decision_type": DecisionType.CANCELLATION, - "transition": StateTransition( - next_state=None, - action=lambda self, event: setattr(self, "_cancel_requested", True), - ), - }, - "activity_task_canceled_event_attributes": { - "type": "canceled", - "decision_type": DecisionType.ACTIVITY, - "transition": StateTransition( - next_state=DecisionState.CANCELED_AFTER_INITIATED - ), - }, - "request_cancel_activity_task_failed_event_attributes": { - "type": "cancel_failed", - "decision_type": DecisionType.CANCELLATION, - "transition": StateTransition( - next_state=None, - action=lambda self, event: setattr(self, "_cancel_emitted", False), - ), - }, - "timer_started_event_attributes": { - "type": "initiated", - "decision_type": DecisionType.TIMER, - "transition": StateTransition(next_state=DecisionState.INITIATED), - }, - "timer_fired_event_attributes": { - "type": "completion", - "decision_type": DecisionType.TIMER, - "transition": StateTransition( - next_state=DecisionState.COMPLETED, - action=lambda self, event: setattr(self, "status", DecisionState.COMPLETED), - ), - }, - "timer_canceled_event_attributes": { - "type": "canceled", - "decision_type": DecisionType.TIMER, - "transition": StateTransition( - next_state=DecisionState.CANCELED_AFTER_INITIATED - ), - }, - "cancel_timer_failed_event_attributes": { - "type": "cancel_failed", - "decision_type": DecisionType.CANCELLATION, - "transition": StateTransition( - next_state=None, - action=lambda self, event: setattr(self, "_cancel_emitted", False), - ), - }, - "start_child_workflow_execution_initiated_event_attributes": { - "type": "initiated", - "decision_type": DecisionType.CHILD_WORKFLOW, - "transition": StateTransition(next_state=DecisionState.INITIATED), - }, - "child_workflow_execution_started_event_attributes": { - "type": "started", - "decision_type": DecisionType.CHILD_WORKFLOW, - "transition": StateTransition(next_state=DecisionState.STARTED), - }, - "child_workflow_execution_completed_event_attributes": { - "type": "completion", - "decision_type": DecisionType.CHILD_WORKFLOW, - "transition": StateTransition( - next_state=DecisionState.COMPLETED, - action=lambda self, event: setattr(self, "status", DecisionState.COMPLETED), - ), - }, - "child_workflow_execution_failed_event_attributes": { - "type": "completion", - "decision_type": DecisionType.CHILD_WORKFLOW, - "transition": StateTransition( - next_state=DecisionState.COMPLETED, - action=lambda self, event: setattr(self, "status", DecisionState.COMPLETED), - ), - }, - "child_workflow_execution_timed_out_event_attributes": { - "type": "completion", - "decision_type": DecisionType.CHILD_WORKFLOW, - "transition": StateTransition( - next_state=DecisionState.COMPLETED, - action=lambda self, event: setattr(self, "status", DecisionState.COMPLETED), - ), - }, - "child_workflow_execution_canceled_event_attributes": { - "type": "canceled", - "decision_type": DecisionType.CHILD_WORKFLOW, - "transition": StateTransition( - next_state=DecisionState.CANCELED_AFTER_INITIATED - ), - }, - "child_workflow_execution_terminated_event_attributes": { - "type": "canceled", - "decision_type": DecisionType.CHILD_WORKFLOW, - "transition": StateTransition( - next_state=DecisionState.CANCELED_AFTER_INITIATED - ), - }, - "start_child_workflow_execution_failed_event_attributes": { - "type": "initiation_failed", - "decision_type": DecisionType.CHILD_WORKFLOW, - "transition": StateTransition( - next_state=DecisionState.COMPLETED, - action=lambda self, event: setattr(self, "status", DecisionState.COMPLETED), - ), - }, -} - - -class BaseDecisionStateMachine: - """Base class for state machines that may emit one or more decisions over time. - - Subclasses are responsible for mapping workflow history events into state - transitions and producing the next set of decisions when queried. - """ - - # Common fields that subclasses may use - scheduled_event_id: Optional[int] = None - started_event_id: Optional[int] = None - - def get_id(self) -> str: - raise NotImplementedError - - def _get_initiated_event_attr_name(self) -> str: - """Return the protobuf attribute name for initiated events.""" - raise NotImplementedError - - def _get_started_event_attr_name(self) -> str: - """Return the protobuf attribute name for started events.""" - raise NotImplementedError - - def _get_completion_event_attr_names(self) -> List[str]: - """Return the protobuf attribute names for completion events.""" - raise NotImplementedError - - def _get_cancel_initiated_event_attr_name(self) -> str: - """Return the protobuf attribute name for cancel initiated events.""" - raise NotImplementedError - - def _get_cancel_failed_event_attr_name(self) -> str: - """Return the protobuf attribute name for cancel failed events.""" - raise NotImplementedError - - def _get_canceled_event_attr_names(self) -> List[str]: - """Return the protobuf attribute names for canceled events.""" - raise NotImplementedError - - def _get_id_field_name(self) -> str: - """Return the field name used to identify this decision in events.""" - raise NotImplementedError - - def _get_event_id_field_name(self) -> str: - """Return the field name used to track event IDs.""" - return "scheduled_event_id" # Default, can be overridden - - def _should_handle_event( - self, event: history.HistoryEvent, attr_name: str, id_field: str - ) -> bool: - """Generic check if this event should be handled by this machine.""" - attr = getattr(event, attr_name, None) - if attr is None: - return False - - # Check if the ID matches - event_id = getattr(attr, id_field, None) - machine_id = getattr(self, self._get_id_field_name(), None) - return event_id == machine_id - - def _should_handle_event_by_event_id( - self, event: history.HistoryEvent, attr_name: str, event_id_field: str - ) -> bool: - """Generic check if this event should be handled by this machine based on event ID.""" - attr = getattr(event, attr_name, None) - if attr is None: - return False - - # Check if the event ID matches our tracked event ID - event_id = getattr(attr, event_id_field, None) - tracked_event_id = getattr(self, self._get_event_id_field_name(), None) - - return event_id == tracked_event_id - - def _default_initiated_action(self, event: history.HistoryEvent) -> None: - """Default action for initiated events.""" - self.status = DecisionState.INITIATED - event_id_field = self._get_event_id_field_name() - setattr(self, event_id_field, event.event_id) - - def _default_started_action(self, event: history.HistoryEvent) -> None: - """Default action for started events.""" - self.status = DecisionState.STARTED - if hasattr(self, "started_event_id"): - self.started_event_id = event.event_id - - def _default_completion_action( - self, event: history.HistoryEvent, attr_name: str - ) -> None: - """Default action for completion events.""" - self.status = DecisionState.COMPLETED - - def _default_cancel_action(self, event: history.HistoryEvent) -> None: - """Default action for cancel events.""" - if self.status == DecisionState.INITIATED: - self.status = DecisionState.CANCELED_AFTER_INITIATED - elif self.status == DecisionState.STARTED: - self.status = DecisionState.CANCELED_AFTER_INITIATED - else: - self.status = DecisionState.CANCELED_AFTER_INITIATED - - def _default_cancel_initiated_action(self, event: history.HistoryEvent) -> None: - """Default action for cancel initiated events.""" - if hasattr(self, "_cancel_requested"): - self._cancel_requested = True - - def _default_cancel_failed_action(self, event: history.HistoryEvent) -> None: - """Default action for cancel failed events.""" - if hasattr(self, "_cancel_emitted"): - self._cancel_emitted = False - - def handle_event(self, event: history.HistoryEvent, event_type: str) -> None: - """Generic event handler that uses the global transition map to determine state changes. - - Args: - event: The history event to process - event_type: The type of event (e.g., 'initiated', 'started', 'completion', etc.) - """ - if event_type == "initiated": - self._handle_initiated_event(event) - elif event_type == "started": - self._handle_started_event(event) - elif event_type == "completion": - self._handle_completion_event(event) - elif event_type == "cancel_initiated": - self._handle_cancel_initiated_event(event) - elif event_type == "cancel_failed": - self._handle_cancel_failed_event(event) - elif event_type == "canceled": - self._handle_canceled_event(event) - elif event_type == "initiation_failed": - self._handle_initiation_failed_event(event) - - def _handle_initiated_event(self, event: history.HistoryEvent) -> None: - """Handle initiated events using the global transition map.""" - attr_name = self._get_initiated_event_attr_name() - id_field = self._get_id_field_name() - - if not self._should_handle_event(event, attr_name, id_field): - return - - transition_info = decision_state_transition_map.get(attr_name) - if transition_info and transition_info["type"] == "initiated": - transition = transition_info["transition"] - if transition.action: - transition.action(self, event) - else: - self._default_initiated_action(event) - - def _handle_started_event(self, event: history.HistoryEvent) -> None: - """Handle started events using the global transition map.""" - attr_name = self._get_started_event_attr_name() - if not attr_name: # Some decision types don't have started events - return - - # Check if this event has the started attribute - if hasattr(event, attr_name): - # Determine the appropriate event ID field based on the decision type - if attr_name == "activity_task_started_event_attributes": - # Activity started events use scheduled_event_id - event_id_field = "scheduled_event_id" - elif attr_name == "child_workflow_execution_started_event_attributes": - # Child workflow started events use initiated_event_id - event_id_field = "initiated_event_id" - else: - event_id_field = self._get_event_id_field_name() - - if not self._should_handle_event_by_event_id( - event, attr_name, event_id_field - ): - return - - transition_info = decision_state_transition_map.get(attr_name) - if transition_info and transition_info["type"] == "started": - transition = transition_info["transition"] - if transition.action: - transition.action(self, event) - else: - self._default_started_action(event) - - def _handle_completion_event(self, event: history.HistoryEvent) -> None: - """Handle completion events using the global transition map.""" - attr_names = self._get_completion_event_attr_names() - - for attr_name in attr_names: - # Check if this event has the completion attribute - if hasattr(event, attr_name): - # Determine the appropriate event ID field based on the decision type - if attr_name == "timer_fired_event_attributes": - # Timer completion events use started_event_id - event_id_field = "started_event_id" - elif attr_name in [ - "activity_task_completed_event_attributes", - "activity_task_failed_event_attributes", - "activity_task_timed_out_event_attributes", - ]: - # Activity completion events use scheduled_event_id - event_id_field = "scheduled_event_id" - elif attr_name in [ - "child_workflow_execution_completed_event_attributes", - "child_workflow_execution_failed_event_attributes", - "child_workflow_execution_timed_out_event_attributes", - ]: - # Child workflow completion events use initiated_event_id - event_id_field = "initiated_event_id" - else: - # Default case - event_id_field = self._get_event_id_field_name() - - # Check if this event should be handled by this machine - if self._should_handle_event_by_event_id( - event, attr_name, event_id_field - ): - transition_info = decision_state_transition_map.get(attr_name) - if transition_info and transition_info["type"] == "completion": - transition = transition_info["transition"] - if transition.action: - transition.action(self, event) - else: - self._default_completion_action(event, attr_name) - break - - def _handle_cancel_initiated_event(self, event: history.HistoryEvent) -> None: - """Handle cancel initiated events using the global transition map.""" - attr_name = self._get_cancel_initiated_event_attr_name() - if not attr_name: # Some decision types don't have cancel initiated events - return - - id_field = self._get_id_field_name() - if not self._should_handle_event(event, attr_name, id_field): - return - - transition_info = decision_state_transition_map.get(attr_name) - if transition_info and transition_info["type"] == "cancel_initiated": - transition = transition_info["transition"] - if transition.action: - transition.action(self, event) - else: - self._default_cancel_initiated_action(event) - - def _handle_cancel_failed_event(self, event: history.HistoryEvent) -> None: - """Handle cancel failed events using the global transition map.""" - attr_name = self._get_cancel_failed_event_attr_name() - if not attr_name: # Some decision types don't have cancel failed events - return - - id_field = self._get_id_field_name() - if not self._should_handle_event(event, attr_name, id_field): - return - - transition_info = decision_state_transition_map.get(attr_name) - if transition_info and transition_info["type"] == "cancel_failed": - transition = transition_info["transition"] - if transition.action: - transition.action(self, event) - else: - self._default_cancel_failed_action(event) - - def _handle_canceled_event(self, event: history.HistoryEvent) -> None: - """Handle canceled events using the global transition map.""" - attr_names = self._get_canceled_event_attr_names() - - for attr_name in attr_names: - # Check if this event has the canceled attribute - if hasattr(event, attr_name): - # Determine the appropriate event ID field based on the decision type - if attr_name == "timer_canceled_event_attributes": - # Timer canceled events use started_event_id - event_id_field = "started_event_id" - elif attr_name == "activity_task_canceled_event_attributes": - # Activity canceled events use scheduled_event_id - event_id_field = "scheduled_event_id" - elif attr_name in [ - "child_workflow_execution_canceled_event_attributes", - "child_workflow_execution_terminated_event_attributes", - ]: - # Child workflow canceled events use initiated_event_id - event_id_field = "initiated_event_id" - else: - # Default case - event_id_field = self._get_event_id_field_name() - - # Check if this event should be handled by this machine - if self._should_handle_event_by_event_id( - event, attr_name, event_id_field - ): - transition_info = decision_state_transition_map.get(attr_name) - if transition_info and transition_info["type"] == "canceled": - transition = transition_info["transition"] - if transition.action: - transition.action(self, event) - else: - self._default_cancel_action(event) - break - - def _handle_initiation_failed_event(self, event: history.HistoryEvent) -> None: - """Handle initiation failed events using the global transition map.""" - # Default implementation - subclasses can override - pass - - def collect_pending_decisions(self) -> List[decision.Decision]: - """Return any decisions that should be emitted now. - - Implementations must be idempotent: repeated calls without intervening - state changes should return the same results (typically empty if already - emitted for current state). - """ - raise NotImplementedError - - -# Activity - - -@dataclass -class ActivityDecisionMachine(BaseDecisionStateMachine): - """Tracks lifecycle of a single activity execution by activity_id.""" - - activity_id: str - schedule_attributes: decision.ScheduleActivityTaskDecisionAttributes - status: DecisionState = DecisionState.CREATED - scheduled_event_id: Optional[int] = None - started_event_id: Optional[int] = None - _schedule_emitted: bool = False - _cancel_requested: bool = False - _cancel_emitted: bool = False - - def get_id(self) -> str: - return self.activity_id - - # Implement abstract methods for generic handlers - def _get_initiated_event_attr_name(self) -> str: - return "activity_task_scheduled_event_attributes" - - def _get_started_event_attr_name(self) -> str: - return "activity_task_started_event_attributes" - - def _get_completion_event_attr_names(self) -> List[str]: - return [ - "activity_task_completed_event_attributes", - "activity_task_failed_event_attributes", - "activity_task_timed_out_event_attributes", - ] - - def _get_cancel_initiated_event_attr_name(self) -> str: - return "activity_task_cancel_requested_event_attributes" - - def _get_cancel_failed_event_attr_name(self) -> str: - return "request_cancel_activity_task_failed_event_attributes" - - def _get_canceled_event_attr_names(self) -> List[str]: - return ["activity_task_canceled_event_attributes"] - - def _get_id_field_name(self) -> str: - return "activity_id" - - def _get_event_id_field_name(self) -> str: - return "scheduled_event_id" - - def collect_pending_decisions(self) -> List[decision.Decision]: - decisions: List[decision.Decision] = [] - - if self.status is DecisionState.CREATED and not self._schedule_emitted: - # Emit initial schedule decision - decisions.append( - decision.Decision( - schedule_activity_task_decision_attributes=self.schedule_attributes - ) - ) - self._schedule_emitted = True - - if ( - self._cancel_requested - and not self._cancel_emitted - and not self.is_terminal() - ): - # Emit cancel request - decisions.append( - decision.Decision( - request_cancel_activity_task_decision_attributes=decision.RequestCancelActivityTaskDecisionAttributes( - activity_id=self.activity_id - ) - ) - ) - self._cancel_emitted = True - - return decisions - - def request_cancel(self) -> None: - if not self.is_terminal(): - self._cancel_requested = True - - def is_terminal(self) -> bool: - return self.status in ( - DecisionState.COMPLETED, - DecisionState.CANCELED_AFTER_INITIATED, - DecisionState.CANCELED_AFTER_STARTED, - DecisionState.COMPLETED_AFTER_CANCELLATION_DECISION_SENT, - ) - - -# Timer - - -@dataclass -class TimerDecisionMachine(BaseDecisionStateMachine): - """Tracks lifecycle of a single workflow timer by timer_id.""" - - timer_id: str - start_attributes: decision.StartTimerDecisionAttributes - status: DecisionState = DecisionState.CREATED - started_event_id: Optional[int] = None - _start_emitted: bool = False - _cancel_requested: bool = False - _cancel_emitted: bool = False - - def get_id(self) -> str: - return self.timer_id - - # Implement abstract methods for generic handlers - def _get_initiated_event_attr_name(self) -> str: - return "timer_started_event_attributes" - - def _get_started_event_attr_name(self) -> str: - return "" # Timers don't have a separate started event - - def _get_completion_event_attr_names(self) -> List[str]: - return ["timer_fired_event_attributes"] - - def _get_cancel_initiated_event_attr_name(self) -> str: - return "" # Timers don't have cancel initiated events - - def _get_cancel_failed_event_attr_name(self) -> str: - return "cancel_timer_failed_event_attributes" - - def _get_canceled_event_attr_names(self) -> List[str]: - return ["timer_canceled_event_attributes"] - - def _get_id_field_name(self) -> str: - return "timer_id" - - def _get_event_id_field_name(self) -> str: - return "started_event_id" - - def collect_pending_decisions(self) -> List[decision.Decision]: - decisions: List[decision.Decision] = [] - - if self.status is DecisionState.CREATED and not self._start_emitted: - decisions.append( - decision.Decision(start_timer_decision_attributes=self.start_attributes) - ) - self._start_emitted = True - - if ( - self._cancel_requested - and not self._cancel_emitted - and not self.is_terminal() - ): - decisions.append( - decision.Decision( - cancel_timer_decision_attributes=decision.CancelTimerDecisionAttributes( - timer_id=self.timer_id - ) - ) - ) - self._cancel_emitted = True - - return decisions - - def request_cancel(self) -> None: - if not self.is_terminal(): - self._cancel_requested = True - - def is_terminal(self) -> bool: - return self.status in ( - DecisionState.COMPLETED, - DecisionState.CANCELED_AFTER_INITIATED, - DecisionState.CANCELED_AFTER_STARTED, - DecisionState.COMPLETED_AFTER_CANCELLATION_DECISION_SENT, - ) - - -# Child Workflow - - -@dataclass -class ChildWorkflowDecisionMachine(BaseDecisionStateMachine): - """Tracks lifecycle of a child workflow start/cancel by client-provided id. - - Cadence history references child workflows via initiated event IDs. For simplicity, - we track by a client-provided identifier (e.g., a unique string) that must map - to attributes.worklow_id when possible. - """ - - client_id: str - start_attributes: decision.StartChildWorkflowExecutionDecisionAttributes - status: DecisionState = DecisionState.CREATED - initiated_event_id: Optional[int] = None - started_event_id: Optional[int] = None - _start_emitted: bool = False - _cancel_requested: bool = False - _cancel_emitted: bool = False - - def get_id(self) -> str: - return self.client_id - - # Implement abstract methods for generic handlers - def _get_initiated_event_attr_name(self) -> str: - return "start_child_workflow_execution_initiated_event_attributes" - - def _get_started_event_attr_name(self) -> str: - return "child_workflow_execution_started_event_attributes" - - def _get_completion_event_attr_names(self) -> List[str]: - return [ - "child_workflow_execution_completed_event_attributes", - "child_workflow_execution_failed_event_attributes", - "child_workflow_execution_timed_out_event_attributes", - ] - - def _get_cancel_initiated_event_attr_name(self) -> str: - return "" # Child workflows don't have cancel initiated events - - def _get_cancel_failed_event_attr_name(self) -> str: - return "" # Child workflows don't have cancel failed events - - def _get_canceled_event_attr_names(self) -> List[str]: - return [ - "child_workflow_execution_canceled_event_attributes", - "child_workflow_execution_terminated_event_attributes", - ] - - def _get_id_field_name(self) -> str: - return "workflow_id" - - def _get_event_id_field_name(self) -> str: - return "initiated_event_id" - - # Override the generic ID check for child workflows since we need to check workflow_id - def _should_handle_event( - self, event: history.HistoryEvent, attr_name: str, id_field: str - ) -> bool: - """Override for child workflows to check workflow_id instead of client_id.""" - attr = getattr(event, attr_name, None) - if attr is None: - return False - - # For child workflows, check if the workflow_id matches - event_workflow_id = getattr(attr, id_field, None) - machine_workflow_id = self.start_attributes.workflow_id - return event_workflow_id == machine_workflow_id - - def collect_pending_decisions(self) -> List[decision.Decision]: - decisions: List[decision.Decision] = [] - - if self.status is DecisionState.CREATED and not self._start_emitted: - decisions.append( - decision.Decision( - start_child_workflow_execution_decision_attributes=self.start_attributes - ) - ) - self._start_emitted = True - - if ( - self._cancel_requested - and not self._cancel_emitted - and not self.is_terminal() - ): - # Request cancel of the child workflow via external cancel with child_workflow_only - decisions.append( - decision.Decision( - request_cancel_external_workflow_execution_decision_attributes=decision.RequestCancelExternalWorkflowExecutionDecisionAttributes( - domain=self.start_attributes.domain, - workflow_execution=common.WorkflowExecution( - workflow_id=self.start_attributes.workflow_id - ), - child_workflow_only=True, - ) - ) - ) - self._cancel_emitted = True - return decisions - - def request_cancel(self) -> None: - if not self.is_terminal(): - self._cancel_requested = True - - def is_terminal(self) -> bool: - return self.status in ( - DecisionState.COMPLETED, - DecisionState.CANCELED_AFTER_INITIATED, - DecisionState.CANCELED_AFTER_STARTED, - DecisionState.COMPLETED_AFTER_CANCELLATION_DECISION_SENT, - ) - - -@dataclass -class DecisionManager: - """Aggregates multiple decision state machines and coordinates decisions. - - Typical flow per decision task: - - Instantiate/update state machines based on application intent and incoming history - - Call collect_pending_decisions() to build the decisions list - - Submit via RespondDecisionTaskCompleted - """ - - activities: Dict[str, ActivityDecisionMachine] = field(default_factory=dict) - timers: Dict[str, TimerDecisionMachine] = field(default_factory=dict) - children: Dict[str, ChildWorkflowDecisionMachine] = field(default_factory=dict) - - # ----- Activity API ----- - - def schedule_activity( - self, activity_id: str, attrs: decision.ScheduleActivityTaskDecisionAttributes - ) -> ActivityDecisionMachine: - machine = self.activities.get(activity_id) - if machine is None or machine.is_terminal(): - machine = ActivityDecisionMachine( - activity_id=activity_id, schedule_attributes=attrs - ) - self.activities[activity_id] = machine - return machine - - def request_cancel_activity(self, activity_id: str) -> None: - machine = self.activities.get(activity_id) - if machine is not None: - machine.request_cancel() - - # ----- Timer API ----- - - def start_timer( - self, timer_id: str, attrs: decision.StartTimerDecisionAttributes - ) -> TimerDecisionMachine: - machine = self.timers.get(timer_id) - if machine is None or machine.is_terminal(): - machine = TimerDecisionMachine(timer_id=timer_id, start_attributes=attrs) - self.timers[timer_id] = machine - return machine - - def cancel_timer(self, timer_id: str) -> None: - machine = self.timers.get(timer_id) - if machine is not None: - machine.request_cancel() - - # ----- Child Workflow API ----- - - def start_child_workflow( - self, - client_id: str, - attrs: decision.StartChildWorkflowExecutionDecisionAttributes, - ) -> ChildWorkflowDecisionMachine: - machine = self.children.get(client_id) - if machine is None or machine.is_terminal(): - machine = ChildWorkflowDecisionMachine( - client_id=client_id, start_attributes=attrs - ) - self.children[client_id] = machine - return machine - - def cancel_child_workflow(self, client_id: str) -> None: - machine = self.children.get(client_id) - if machine is not None: - machine.request_cancel() - - # ----- History routing ----- - - def handle_history_event(self, event: history.HistoryEvent) -> None: - """Dispatch history event to typed handlers using the global transition map.""" - attr = event.WhichOneof("attributes") - - # Look up the event type from the global transition map - transition_info = decision_state_transition_map.get(attr) - if transition_info: - event_type = transition_info["type"] - # Route to all relevant machines using the new unified handle_event method - for activity_machine in list(self.activities.values()): - activity_machine.handle_event(event, event_type) - for timer_machine in list(self.timers.values()): - timer_machine.handle_event(event, event_type) - for child_machine in list(self.children.values()): - child_machine.handle_event(event, event_type) - - # ----- Decision aggregation ----- - - def collect_pending_decisions(self) -> List[decision.Decision]: - decisions: List[decision.Decision] = [] - - # Activities - for machine in list(self.activities.values()): - decisions.extend(machine.collect_pending_decisions()) - - # Timers - for timer_machine in list(self.timers.values()): - decisions.extend(timer_machine.collect_pending_decisions()) - - # Children - for child_machine in list(self.children.values()): - decisions.extend(child_machine.collect_pending_decisions()) - - return decisions diff --git a/cadence/_internal/rpc/error.py b/cadence/_internal/rpc/error.py index c8cf3ea..708bd36 100644 --- a/cadence/_internal/rpc/error.py +++ b/cadence/_internal/rpc/error.py @@ -78,10 +78,10 @@ async def intercept_unary_unary( return CadenceErrorUnaryUnaryCall(rpc_call) -def map_error(e: AioRpcError) -> error.CadenceError: +def map_error(e: AioRpcError) -> error.CadenceRpcError: status: Status | None = from_call(e) if not status or not status.details: - return error.CadenceError(e.details(), e.code()) + return error.CadenceRpcError(e.details(), e.code()) details = status.details[0] if details.Is(error_pb2.WorkflowExecutionAlreadyStartedError.DESCRIPTOR): @@ -145,4 +145,4 @@ def map_error(e: AioRpcError) -> error.CadenceError: details.Unpack(service_busy) return error.ServiceBusyError(e.details(), e.code(), service_busy.reason) else: - return error.CadenceError(e.details(), e.code()) + return error.CadenceRpcError(e.details(), e.code()) diff --git a/cadence/_internal/rpc/retry.py b/cadence/_internal/rpc/retry.py index dd3fd35..9947691 100644 --- a/cadence/_internal/rpc/retry.py +++ b/cadence/_internal/rpc/retry.py @@ -5,7 +5,7 @@ from grpc import StatusCode from grpc.aio import UnaryUnaryClientInterceptor, ClientCallDetails -from cadence.error import CadenceError, EntityNotExistsError +from cadence.error import CadenceRpcError, EntityNotExistsError RETRYABLE_CODES = { StatusCode.INTERNAL, @@ -73,7 +73,7 @@ async def intercept_unary_unary( try: # Return the result directly if success. GRPC will wrap it back into a UnaryUnaryCall return await rpc_call - except CadenceError as e: + except CadenceRpcError as e: err = e attempts += 1 @@ -90,7 +90,7 @@ async def intercept_unary_unary( return rpc_call -def is_retryable(err: CadenceError, call_details: ClientCallDetails) -> bool: +def is_retryable(err: CadenceRpcError, call_details: ClientCallDetails) -> bool: # Handle requests to the passive side, matching the Go and Java Clients if call_details.method == GET_WORKFLOW_HISTORY and isinstance( err, EntityNotExistsError diff --git a/cadence/_internal/workflow/__init__.py b/cadence/_internal/workflow/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cadence/_internal/workflow/context.py b/cadence/_internal/workflow/context.py index e2516ab..d008ce4 100644 --- a/cadence/_internal/workflow/context.py +++ b/cadence/_internal/workflow/context.py @@ -1,14 +1,31 @@ -from typing import Optional +from datetime import timedelta +from math import ceil +from typing import Optional, Any, Unpack, Type, cast + +from cadence._internal.workflow.statemachine.decision_manager import DecisionManager +from cadence._internal.workflow.decisions_helper import DecisionsHelper +from cadence.api.v1.common_pb2 import ActivityType +from cadence.api.v1.decision_pb2 import ScheduleActivityTaskDecisionAttributes +from cadence.api.v1.tasklist_pb2 import TaskList, TaskListKind from cadence.client import Client -from cadence.workflow import WorkflowContext, WorkflowInfo +from cadence.data_converter import DataConverter +from cadence.workflow import WorkflowContext, WorkflowInfo, ResultType, ActivityOptions class Context(WorkflowContext): - def __init__(self, client: Client, info: WorkflowInfo): + def __init__( + self, + client: Client, + info: WorkflowInfo, + decision_helper: DecisionsHelper, + decision_manager: DecisionManager, + ): self._client = client self._info = info self._replay_mode = True self._replay_current_time_milliseconds: Optional[int] = None + self._decision_helper = decision_helper + self._decision_manager = decision_manager def info(self) -> WorkflowInfo: return self._info @@ -16,6 +33,73 @@ def info(self) -> WorkflowInfo: def client(self) -> Client: return self._client + def data_converter(self) -> DataConverter: + return self._client.data_converter + + async def execute_activity( + self, + activity: str, + result_type: Type[ResultType], + *args: Any, + **kwargs: Unpack[ActivityOptions], + ) -> ResultType: + opts = ActivityOptions(**kwargs) + if "schedule_to_close_timeout" not in opts and ( + "schedule_to_start_timeout" not in opts + or "start_to_close_timeout" not in opts + ): + raise ValueError( + "Either schedule_to_close_timeout or both schedule_to_start_timeout and start_to_close_timeout must be specified" + ) + + schedule_to_close = opts.get("schedule_to_close_timeout", None) + schedule_to_start = opts.get("schedule_to_start_timeout", None) + start_to_close = opts.get("start_to_close_timeout", None) + heartbeat = opts.get("heartbeat_timeout", None) + + if schedule_to_close is None: + schedule_to_close = schedule_to_start + start_to_close # type: ignore + + if start_to_close is None: + start_to_close = schedule_to_close + + if schedule_to_start is None: + schedule_to_start = schedule_to_close + + if heartbeat is None: + heartbeat = schedule_to_close + + task_list = ( + opts["task_list"] + if opts.get("task_list", None) + else self._info.workflow_task_list + ) + + activity_input = self.data_converter().to_data(list(args)) + activity_id = self._decision_helper.generate_activity_id(activity) + schedule_attributes = ScheduleActivityTaskDecisionAttributes( + activity_id=activity_id, + activity_type=ActivityType(name=activity), + domain=self._client.domain, + task_list=TaskList(kind=TaskListKind.TASK_LIST_KIND_NORMAL, name=task_list), + input=activity_input, + schedule_to_close_timeout=_round_to_nearest_second(schedule_to_close), + schedule_to_start_timeout=_round_to_nearest_second(schedule_to_start), + start_to_close_timeout=_round_to_nearest_second(start_to_close), + heartbeat_timeout=_round_to_nearest_second(heartbeat), + retry_policy=None, + header=None, + request_local_dispatch=False, + ) + + result_payload = await self._decision_manager.schedule_activity( + schedule_attributes + ) + + result = self.data_converter().from_data(result_payload, [result_type])[0] + + return cast(ResultType, result) + def set_replay_mode(self, replay: bool) -> None: """Set whether the workflow is currently in replay mode.""" self._replay_mode = replay @@ -31,3 +115,7 @@ def set_replay_current_time_milliseconds(self, time_millis: int) -> None: def get_replay_current_time_milliseconds(self) -> Optional[int]: """Get the current replay time in milliseconds.""" return self._replay_current_time_milliseconds + + +def _round_to_nearest_second(delta: timedelta) -> timedelta: + return timedelta(seconds=ceil(delta.total_seconds())) diff --git a/cadence/_internal/workflow/decisions_helper.py b/cadence/_internal/workflow/decisions_helper.py index d92fb73..65bcb86 100644 --- a/cadence/_internal/workflow/decisions_helper.py +++ b/cadence/_internal/workflow/decisions_helper.py @@ -9,10 +9,9 @@ from dataclasses import dataclass from typing import Dict, Optional -from cadence._internal.decision_state_machine import ( +from cadence._internal.workflow.statemachine.decision_state_machine import ( DecisionId, DecisionType, - DecisionManager, ) logger = logging.getLogger(__name__) @@ -37,17 +36,13 @@ class DecisionsHelper: state machines for proper decision lifecycle tracking. """ - def __init__(self, decision_manager: DecisionManager): + def __init__(self): """ Initialize the DecisionsHelper with a DecisionManager reference. - - Args: - decision_manager: The DecisionManager containing the state machines """ self._next_decision_counters: Dict[DecisionType, int] = {} self._tracked_decisions: Dict[str, DecisionTracker] = {} self._decision_id_to_key: Dict[str, str] = {} - self._decision_manager = decision_manager logger.debug("DecisionsHelper initialized with DecisionManager integration") def _get_next_counter(self, decision_type: DecisionType) -> int: diff --git a/cadence/_internal/workflow/statemachine/__init__.py b/cadence/_internal/workflow/statemachine/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cadence/_internal/workflow/statemachine/activity_state_machine.py b/cadence/_internal/workflow/statemachine/activity_state_machine.py new file mode 100644 index 0000000..2cd6e72 --- /dev/null +++ b/cadence/_internal/workflow/statemachine/activity_state_machine.py @@ -0,0 +1,106 @@ +from cadence._internal.workflow.statemachine.decision_state_machine import ( + DecisionState, + DecisionType, + DecisionId, + DecisionFuture, + BaseDecisionStateMachine, +) +from cadence._internal.workflow.statemachine.event_dispatcher import EventDispatcher +from cadence.api.v1 import decision, history +from cadence.api.v1.common_pb2 import Payload +from cadence.error import ActivityFailure + +activity_events = EventDispatcher("scheduled_event_id") + + +class ActivityStateMachine(BaseDecisionStateMachine): + request: decision.ScheduleActivityTaskDecisionAttributes + completed: DecisionFuture[Payload] + + def __init__( + self, + request: decision.ScheduleActivityTaskDecisionAttributes, + completed: DecisionFuture[Payload], + ) -> None: + super().__init__() + self.request = request + self.completed = completed + + def get_id(self) -> DecisionId: + return DecisionId(DecisionType.ACTIVITY, self.request.activity_id) + + def get_decision(self) -> decision.Decision | None: + if self.state is DecisionState.CREATED: + return decision.Decision( + schedule_activity_task_decision_attributes=self.request + ) + + if self.state is DecisionState.CANCELED_AFTER_INITIATED: + return decision.Decision( + request_cancel_activity_task_decision_attributes=decision.RequestCancelActivityTaskDecisionAttributes( + activity_id=self.request.activity_id, + ) + ) + + return None + + def request_cancel(self) -> bool: + if self.state is DecisionState.CREATED: + self._transition(DecisionState.COMPLETED) + self.completed.force_cancel() + return True + + if self.state is DecisionState.INITIATED: + self._transition(DecisionState.CANCELED_AFTER_INITIATED) + return True + + return False + + @activity_events.event(id_attr="activity_id", event_id_is_alias=True) + def handle_scheduled(self, _: history.ActivityTaskScheduledEventAttributes) -> None: + self._transition(DecisionState.INITIATED) + + @activity_events.event() + def handle_started(self, _: history.ActivityTaskStartedEventAttributes) -> None: + # Started doesn't actually do anything in the Go client. + # The workflow can't observe it, and it doesn't impact cancellation + # self._transition(DecisionState.STARTED) + pass + + @activity_events.event() + def handle_completed( + self, event: history.ActivityTaskCompletedEventAttributes + ) -> None: + self._transition(DecisionState.COMPLETED) + self.completed.set_result(event.result) + + @activity_events.event() + def handle_failed(self, event: history.ActivityTaskFailedEventAttributes) -> None: + self._transition(DecisionState.COMPLETED) + self.completed.set_exception(ActivityFailure(event.failure.reason)) + + @activity_events.event() + def handle_timeout( + self, event: history.ActivityTaskTimedOutEventAttributes + ) -> None: + self._transition(DecisionState.COMPLETED) + self.completed.set_exception(ActivityFailure(event.details.data.decode())) + + @activity_events.event() + def handle_canceled( + self, event: history.ActivityTaskCanceledEventAttributes + ) -> None: + self._transition(DecisionState.COMPLETED) + self.completed.force_cancel(event.details.data.decode()) + + @activity_events.event("activity_id") + def handle_cancel_requested( + self, _: history.ActivityTaskCancelRequestedEventAttributes + ) -> None: + self._transition(DecisionState.CANCELLATION_DECISION_SENT) + + @activity_events.event("activity_id") + def handle_cancel_failed( + self, _: history.RequestCancelActivityTaskFailedEventAttributes + ) -> None: + self._transition(DecisionState.INITIATED) diff --git a/cadence/_internal/workflow/statemachine/decision_manager.py b/cadence/_internal/workflow/statemachine/decision_manager.py new file mode 100644 index 0000000..d811ee2 --- /dev/null +++ b/cadence/_internal/workflow/statemachine/decision_manager.py @@ -0,0 +1,157 @@ +import asyncio +from collections import OrderedDict +from dataclasses import dataclass, field +from typing import Dict, Type, Tuple, ClassVar, List + +from cadence._internal.workflow.statemachine.activity_state_machine import ( + activity_events, + ActivityStateMachine, +) +from cadence._internal.workflow.statemachine.decision_state_machine import ( + DecisionId, + DecisionStateMachine, + DecisionType, + DecisionFuture, +) +from cadence._internal.workflow.statemachine.event_dispatcher import ( + EventDispatcher, + Action, +) +from cadence._internal.workflow.statemachine.timer_state_machine import ( + TimerStateMachine, + timer_events, +) +from cadence.api.v1 import decision, history +from cadence.api.v1.common_pb2 import Payload + +DecisionAlias = Tuple[DecisionType, str | int] + + +@dataclass(frozen=True) +class EventDispatch: + decision_type: DecisionType + action: Action + + +def _create_dispatch_map( + dispatchers: dict[DecisionType, EventDispatcher], +) -> dict[Type, EventDispatch]: + result: dict[Type, EventDispatch] = {} + for decision_type, dispatcher in dispatchers.items(): + for event_type, action in dispatcher.handlers.items(): + if event_type in result: + raise ValueError( + f"Received duplicate registration for {event_type}: {decision_type} and {result[event_type].decision_type}" + ) + result[event_type] = EventDispatch(decision_type, action) + + return result + + +@dataclass +class DecisionManager: + """Aggregates multiple decision state machines and coordinates decisions. + + Typical flow per decision task: + - Instantiate/update state machines based on application intent and incoming history + - Call collect_pending_decisions() to build the decisions list + - Submit via RespondDecisionTaskCompleted + """ + + type_to_action: ClassVar[Dict[Type, EventDispatch]] = _create_dispatch_map( + { + DecisionType.ACTIVITY: activity_events, + DecisionType.TIMER: timer_events, + } + ) + state_machines: OrderedDict[DecisionId, DecisionStateMachine] = field( + default_factory=OrderedDict + ) + aliases: Dict[DecisionAlias, DecisionStateMachine] = field(default_factory=dict) + + # ----- Activity API ----- + + def schedule_activity( + self, attrs: decision.ScheduleActivityTaskDecisionAttributes + ) -> asyncio.Future[Payload]: + decision_id = DecisionId(DecisionType.ACTIVITY, attrs.activity_id) + future = DecisionFuture[Payload](lambda: self._request_cancel(decision_id)) + machine = ActivityStateMachine(attrs, future) + self._add_state_machine(machine) + + return future + + # ----- Timer API ----- + + def start_timer( + self, attrs: decision.StartTimerDecisionAttributes + ) -> asyncio.Future[None]: + decision_id = DecisionId(DecisionType.TIMER, attrs.timer_id) + future = DecisionFuture[None](lambda: self._request_cancel(decision_id)) + machine = TimerStateMachine(attrs, future) + self._add_state_machine(machine) + + return future + + def _get_machine(self, decision_id: DecisionId) -> DecisionStateMachine: + machine = self.state_machines.get(decision_id, None) + if machine is None: + raise ValueError(f"Unknown state machine: {decision_id}") + return machine + + def _add_state_machine(self, state: DecisionStateMachine) -> None: + decision_id = state.get_id() + if decision_id in self.state_machines: + raise ValueError(f"Received duplicate decision: {decision_id}") + self.state_machines[decision_id] = state + self.aliases[(decision_id.decision_type, decision_id.id)] = state + + # ----- History routing ----- + + def handle_history_event(self, event: history.HistoryEvent) -> None: + """Dispatch history event to typed handlers using the global transition map.""" + attr = event.WhichOneof("attributes") + # Based on the type of the event, determine what DecisionType it's referencing and + # the correct action to take + event_attributes = getattr(event, attr) + event_action = DecisionManager.type_to_action.get( + event_attributes.__class__, None + ) + if event_action is not None: + decision_type = event_action.decision_type + action = event_action.action + # Find what state machine the event references. + # This may be a reference via the user id or a reference to a previous event + id_for_event = getattr(event_attributes, action.id_attr) + alias = (decision_type, id_for_event) + machine = self.aliases.get(alias, None) + if machine is None: + raise KeyError( + f"Event {event.event_id} references unknown state machine {alias}" + ) + + action.fn(machine, event_attributes) + + # Certain events (scheduled) are often referenced by subsequent events + # rather than using the client provided id + if action.event_id_is_alias: + self.aliases[(decision_type, event.event_id)] = machine + + # ----- Decision aggregation ----- + + def collect_pending_decisions(self) -> List[decision.Decision]: + decisions: List[decision.Decision] = [] + + for machine in self.state_machines.values(): + to_send = machine.get_decision() + if to_send is not None: + decisions.append(to_send) + + return decisions + + def _request_cancel(self, decision_id: DecisionId) -> bool: + machine = self._get_machine(decision_id) + # Interactions with the state machines should move them to the end so that the decisions are ordered as they + # happened in the Workflow + self.state_machines.move_to_end(decision_id) + return machine.request_cancel() diff --git a/cadence/_internal/workflow/statemachine/decision_state_machine.py b/cadence/_internal/workflow/statemachine/decision_state_machine.py new file mode 100644 index 0000000..4bcc2c9 --- /dev/null +++ b/cadence/_internal/workflow/statemachine/decision_state_machine.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +import asyncio +from dataclasses import dataclass +from enum import Enum +from typing import Callable, Protocol, TypeVar, Optional + +from cadence.api.v1 import ( + decision_pb2 as decision, +) + + +# TODO: Remove unused states +class DecisionState(Enum): + """Lifecycle states for a decision-producing state machine instance.""" + + CREATED = 0 + DECISION_SENT = 1 + CANCELED_BEFORE_INITIATED = 2 + INITIATED = 3 + STARTED = 4 + CANCELED_AFTER_INITIATED = 5 + CANCELED_AFTER_STARTED = 6 + CANCELLATION_DECISION_SENT = 7 + COMPLETED_AFTER_CANCELLATION_DECISION_SENT = 8 + COMPLETED = 9 + + +class DecisionType(Enum): + """Types of decisions that can be made by state machines.""" + + ACTIVITY = 0 + CHILD_WORKFLOW = 1 + CANCELLATION = 2 + MARKER = 3 + TIMER = 4 + SIGNAL = 5 + UPSERT_SEARCH_ATTRIBUTES = 6 + + +@dataclass(frozen=True) +class DecisionId: + decision_type: DecisionType + id: str + + +class DecisionStateMachine(Protocol): + def get_id(self) -> DecisionId: ... + + def get_decision(self) -> decision.Decision | None: ... + + def request_cancel(self) -> bool: ... + + +class BaseDecisionStateMachine(DecisionStateMachine): + def __init__(self): + self._state = DecisionState.CREATED + + def _transition( + self, to: DecisionState, allowed_from: list[DecisionState] | None = None + ) -> None: + # TODO: Maybe track previous states like the other clients + if allowed_from and self.state not in allowed_from: + raise RuntimeError(f"unable to transition to {to} from {self.state}") + self._state = to + + @property + def state(self) -> DecisionState: + return self._state + + +T = TypeVar("T") +CancelFn = Callable[[], bool] + + +class DecisionFuture(asyncio.Future[T]): + def __init__(self, request_cancel: CancelFn | None = None) -> None: + super().__init__() + if request_cancel is None: + request_cancel = self.force_cancel + self._request_cancel = request_cancel + + def force_cancel(self, message: Optional[str] = None) -> bool: + return super().cancel(message) + + def cancel(self, msg=None) -> bool: + return self._request_cancel() diff --git a/cadence/_internal/workflow/statemachine/event_dispatcher.py b/cadence/_internal/workflow/statemachine/event_dispatcher.py new file mode 100644 index 0000000..94d2977 --- /dev/null +++ b/cadence/_internal/workflow/statemachine/event_dispatcher.py @@ -0,0 +1,76 @@ +from dataclasses import dataclass +from inspect import signature +from typing import Type, Callable, get_type_hints, TypeVar, Any, cast + +from google.protobuf.message import Message + +T = TypeVar("T") + +EventHandler = Callable[[Any, T], None] + + +@dataclass( + frozen=True, +) +class Action: + fn: EventHandler + id_attr: str + event_id_is_alias: bool + + +class EventDispatcher: + handlers: dict[Type, Action] + + def __init__(self, default_id_attr: str) -> None: + self._default_id_attr = default_id_attr + self.handlers = {} + + def event( + self, id_attr: str = "", event_id_is_alias: bool = False + ) -> Callable[[EventHandler], EventHandler]: + def decorator(func: EventHandler) -> EventHandler: + event_type = _find_event_type(func) + event_id_attr = id_attr if id_attr else self._default_id_attr + + _validate_field(func, event_type, event_id_attr) + if event_type in self.handlers: + raise ValueError( + f"Duplicate handler for {event_type}: {func.__qualname__} and {self.handlers[event_type].fn.__qualname__}" + ) + self.handlers[event_type] = Action(func, event_id_attr, event_id_is_alias) + return func + + return decorator + + +def _find_event_type(func: EventHandler) -> Type[Message]: + sig = signature(func) + type_hints = get_type_hints(func) + if len(sig.parameters) != 2: + raise ValueError( + f"Expected 2 arguments (self, event), {func.__qualname__} has: {sig.parameters}" + ) + (non_self_param, _) = list(sig.parameters.items())[1] + if non_self_param not in type_hints: + raise ValueError(f"Missing type hint on {func.__qualname__}: {non_self_param}") + if "return" in type_hints and type_hints["return"] != None.__class__: + raise ValueError( + f"Event methods must return None, {func.__qualname__} returns: {type_hints['return']}" + ) + + event_type = type_hints[non_self_param] + if not issubclass(event_type, Message): + raise ValueError( + f"Event methods must accept a Message, {func.__qualname__} accepts: {event_type}" + ) + + # Mypy struggles without this for some reason, despite type narrowing being supported + return cast(Type[Message], event_type) + + +def _validate_field(func: EventHandler, event_type: Type[Message], field: str) -> None: + fields = event_type.DESCRIPTOR.fields_by_name + if field not in fields: + raise ValueError( + f"{func.__qualname__} handles {event_type.__qualname__}, which has no field {field}" + ) diff --git a/cadence/_internal/workflow/statemachine/timer_state_machine.py b/cadence/_internal/workflow/statemachine/timer_state_machine.py new file mode 100644 index 0000000..77d5a80 --- /dev/null +++ b/cadence/_internal/workflow/statemachine/timer_state_machine.py @@ -0,0 +1,73 @@ +from cadence._internal.workflow.statemachine.decision_state_machine import ( + DecisionState, + DecisionFuture, + DecisionType, + DecisionId, + BaseDecisionStateMachine, +) +from cadence._internal.workflow.statemachine.event_dispatcher import EventDispatcher +from cadence.api.v1 import decision, history + +timer_events = EventDispatcher("timer_id") + + +class TimerStateMachine(BaseDecisionStateMachine): + request: decision.StartTimerDecisionAttributes + completed: DecisionFuture[None] + + def __init__( + self, + request: decision.StartTimerDecisionAttributes, + completed: DecisionFuture[None], + ) -> None: + super().__init__() + self.request = request + self.completed = completed + + def get_id(self) -> DecisionId: + return DecisionId(DecisionType.TIMER, self.request.timer_id) + + def get_decision(self) -> decision.Decision | None: + if self.state is DecisionState.CREATED: + return decision.Decision(start_timer_decision_attributes=self.request) + if self.state is DecisionState.CANCELED_AFTER_INITIATED: + return decision.Decision( + cancel_timer_decision_attributes=decision.CancelTimerDecisionAttributes( + timer_id=self.request.timer_id, + ) + ) + return None + + def request_cancel(self) -> bool: + if self.state is DecisionState.CREATED: + self._transition(DecisionState.COMPLETED) + self.completed.force_cancel() + return True + + if self.state is DecisionState.INITIATED: + self._transition(DecisionState.CANCELED_AFTER_INITIATED) + self.completed.force_cancel() + return True + + return False + + @timer_events.event() + def handle_started(self, _: history.TimerStartedEventAttributes) -> None: + self._transition(DecisionState.INITIATED) + + @timer_events.event() + def handle_fired(self, _: history.TimerFiredEventAttributes) -> None: + self._transition(DecisionState.COMPLETED) + self.completed.set_result(None) + + @timer_events.event() + def handle_canceled(self, _: history.TimerCanceledEventAttributes) -> None: + # Timers resolve immediately regardless of the outcome of the cancellation + self._transition(DecisionState.COMPLETED) + + @timer_events.event() + def handle_cancel_failed(self, _: history.CancelTimerFailedEventAttributes) -> None: + # This leaves the timer in a likely invalid state, but matches the other clients. + # The only way for timer cancellation to fail is if the timer ID isn't known, so this + # can't really happen in the first place. + self._transition(DecisionState.INITIATED) diff --git a/cadence/_internal/workflow/workflow_engine.py b/cadence/_internal/workflow/workflow_engine.py index 627434b..997a287 100644 --- a/cadence/_internal/workflow/workflow_engine.py +++ b/cadence/_internal/workflow/workflow_engine.py @@ -6,11 +6,11 @@ from cadence._internal.workflow.context import Context from cadence._internal.workflow.decisions_helper import DecisionsHelper from cadence._internal.workflow.decision_events_iterator import DecisionEventsIterator +from cadence._internal.workflow.statemachine.decision_manager import DecisionManager from cadence.api.v1.decision_pb2 import Decision from cadence.client import Client from cadence.api.v1.service_worker_pb2 import PollForDecisionTaskResponse from cadence.workflow import WorkflowInfo -from cadence._internal.decision_state_machine import DecisionManager logger = logging.getLogger(__name__) @@ -22,13 +22,15 @@ class DecisionResult: class WorkflowEngine: def __init__(self, info: WorkflowInfo, client: Client, workflow_definition=None): - self._context = Context(client, info) self._workflow_definition = workflow_definition self._workflow_instance = None if workflow_definition: self._workflow_instance = workflow_definition.cls() self._decision_manager = DecisionManager() - self._decisions_helper = DecisionsHelper(self._decision_manager) + self._decisions_helper = DecisionsHelper() + self._context = Context( + client, info, self._decisions_helper, self._decision_manager + ) self._is_workflow_complete = False async def process_decision( @@ -300,7 +302,7 @@ async def _execute_workflow_function( ) # Extract workflow input from history - workflow_input = await self._extract_workflow_input(decision_task) + workflow_input = self._extract_workflow_input(decision_task) # Execute workflow function result = await self._execute_workflow_function_once( @@ -334,7 +336,7 @@ async def _execute_workflow_function( ) raise - async def _extract_workflow_input( + def _extract_workflow_input( self, decision_task: PollForDecisionTaskResponse ) -> Any: """ @@ -359,7 +361,7 @@ async def _extract_workflow_input( try: # Use from_data method with a single type hint of None (no type conversion) input_data_list = ( - await self._context.client().data_converter.from_data( + self._context.client().data_converter.from_data( started_attrs.input, [None] ) ) diff --git a/cadence/activity.py b/cadence/activity.py index 581bea9..0394d96 100644 --- a/cadence/activity.py +++ b/cadence/activity.py @@ -19,9 +19,12 @@ get_type_hints, Any, overload, + Tuple, + Sequence, ) from cadence import Client +from cadence.workflow import WorkflowContext, ActivityOptions, execute_activity @dataclass(frozen=True) @@ -81,7 +84,8 @@ def get() -> "ActivityContext": class ActivityParameter: name: str type_hint: Type | None - default_value: Any | None + has_default: bool + default_value: Any class ExecutionStrategy(Enum): @@ -104,16 +108,43 @@ def __init__( name: str, strategy: ExecutionStrategy, params: list[ActivityParameter], + result_type: Type[T], ): self._wrapped = wrapped self._name = name self._strategy = strategy self._params = params + self._result_type = result_type + self._execution_options = ActivityOptions() update_wrapper(self, wrapped) def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: + if WorkflowContext.is_set(): + # If the original function is async then this is fine + # If it's not async then this is invalid typing, but still allowed + # Users can use execute as a guaranteed type safe option if the function is sync + return self.execute(*args, **kwargs) # type: ignore return self._wrapped(*args, **kwargs) + def with_options( + self, **kwargs: Unpack[ActivityOptions] + ) -> "ActivityDefinition[P, T]": + res = ActivityDefinition( + self._wrapped, self._name, self.strategy, self.params, self.result_type + ) + new_opts = self._execution_options.copy() + new_opts.update(kwargs) + res._execution_options = new_opts + return res + + async def execute(self, *args: P.args, **kwargs: P.kwargs) -> T: + return await execute_activity( + self.name, + self.result_type, + *_to_parameters(self.params, args, kwargs), + **self._execution_options, + ) + @property def name(self) -> str: return self._name @@ -126,6 +157,10 @@ def strategy(self) -> ExecutionStrategy: def params(self) -> list[ActivityParameter]: return self._params + @property + def result_type(self) -> Type[T]: + return self._result_type + @staticmethod def wrap( fn: Callable[P, T], opts: ActivityDefinitionOptions @@ -138,8 +173,8 @@ def wrap( if inspect.iscoroutinefunction(fn) or inspect.iscoroutinefunction(fn.__call__): # type: ignore strategy = ExecutionStrategy.ASYNC - params = _get_params(fn) - return ActivityDefinition(fn, name, strategy, params) + params, result_type = _get_signature(fn) + return ActivityDefinition(fn, name, strategy, params, result_type) ActivityDecorator = Callable[[Callable[P, T]], ActivityDefinition[P, T]] @@ -167,25 +202,54 @@ def decorator(inner_fn: Callable[P, T]) -> ActivityDefinition[P, T]: return decorator -def _get_params(fn: Callable) -> list[ActivityParameter]: - args = signature(fn).parameters +def _get_signature(fn: Callable[P, T]) -> Tuple[list[ActivityParameter], Type[T]]: + sig = signature(fn) + args = sig.parameters hints = get_type_hints(fn) - result = [] + params = [] for name, param in args.items(): - # "unbound functions" aren't a thing in the Python spec. Filter out the self parameter and hope they followed - # the convention. + # "unbound functions" aren't a thing in the Python spec. We don't have a way to determine whether the function + # is part of a class or is standalone. + # Filter out the self parameter and hope they followed the convention. if param.name == "self": continue default = None + has_default = False if param.default != Parameter.empty: default = param.default + has_default = param.default is not None if param.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD): type_hint = hints.get(name, None) - result.append(ActivityParameter(name, type_hint, default)) - + params.append(ActivityParameter(name, type_hint, has_default, default)) else: raise ValueError( f"Parameters must be positional. {name} is {param.kind}, and not valid" ) + # Treat unspecified return type + return_type = hints.get("return", dict) + + return params, return_type + + +def _to_parameters( + params: list[ActivityParameter], args: Sequence[Any], kwargs: dict[str, Any] +) -> list[Any]: + result: list[Any] = [] + for value, param_spec in zip(args, params): + result.append(value) + + i = len(result) + while i < len(params): + param = params[i] + if param.name not in kwargs and not param.has_default: + raise ValueError(f"Missing parameter: {param.name}") + + value = kwargs.pop(param.name, param.default_value) + result.append(value) + i = i + 1 + + if len(kwargs) > 0: + raise ValueError(f"Unexpected keyword arguments: {kwargs}") + return result diff --git a/cadence/client.py b/cadence/client.py index dcad7f3..a75d7b5 100644 --- a/cadence/client.py +++ b/cadence/client.py @@ -130,7 +130,7 @@ async def __aenter__(self) -> "Client": async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: await self.close() - async def _build_start_workflow_request( + def _build_start_workflow_request( self, workflow: Union[str, Callable], args: tuple[Any, ...], @@ -151,7 +151,7 @@ async def _build_start_workflow_request( input_payload = None if args: try: - input_payload = await self.data_converter.to_data(list(args)) + input_payload = self.data_converter.to_data(list(args)) except Exception as e: raise ValueError(f"Failed to encode workflow arguments: {e}") @@ -209,7 +209,7 @@ async def start_workflow( options = _validate_and_apply_defaults(StartWorkflowOptions(**options_kwargs)) # Build the gRPC request - request = await self._build_start_workflow_request(workflow, args, options) + request = self._build_start_workflow_request(workflow, args, options) # Execute the gRPC call try: diff --git a/cadence/data_converter.py b/cadence/data_converter.py index 16851bf..cae382d 100644 --- a/cadence/data_converter.py +++ b/cadence/data_converter.py @@ -10,13 +10,11 @@ class DataConverter(Protocol): @abstractmethod - async def from_data( - self, payload: Payload, type_hints: List[Type | None] - ) -> List[Any]: + def from_data(self, payload: Payload, type_hints: List[Type | None]) -> List[Any]: raise NotImplementedError() @abstractmethod - async def to_data(self, values: List[Any]) -> Payload: + def to_data(self, values: List[Any]) -> Payload: raise NotImplementedError() @@ -26,9 +24,7 @@ def __init__(self) -> None: # Need to use std lib decoder in order to decode the custom whitespace delimited data format self._decoder = JSONDecoder(strict=False) - async def from_data( - self, payload: Payload, type_hints: List[Type | None] - ) -> List[Any]: + def from_data(self, payload: Payload, type_hints: List[Type | None]) -> List[Any]: if not payload.data: return DefaultDataConverter._convert_into([], type_hints) @@ -72,7 +68,7 @@ def _get_default(type_hint: Type) -> Any: return False return None - async def to_data(self, values: List[Any]) -> Payload: + def to_data(self, values: List[Any]) -> Payload: result = bytearray() for index, value in enumerate(values): self._encoder.encode_into(value, result, -1) diff --git a/cadence/error.py b/cadence/error.py index 1d3be6d..3437f89 100644 --- a/cadence/error.py +++ b/cadence/error.py @@ -1,15 +1,18 @@ import grpc -class CadenceError(Exception): +class ActivityFailure(Exception): + def __init__(self, message: str) -> None: + super().__init__(message) + + +class CadenceRpcError(Exception): def __init__(self, message: str, code: grpc.StatusCode, *args): super().__init__(message, code, *args) self.code = code - pass - -class WorkflowExecutionAlreadyStartedError(CadenceError): +class WorkflowExecutionAlreadyStartedError(CadenceRpcError): def __init__( self, message: str, code: grpc.StatusCode, start_request_id: str, run_id: str ) -> None: @@ -18,7 +21,7 @@ def __init__( self.run_id = run_id -class EntityNotExistsError(CadenceError): +class EntityNotExistsError(CadenceRpcError): def __init__( self, message: str, @@ -35,11 +38,11 @@ def __init__( self.active_clusters = active_clusters -class WorkflowExecutionAlreadyCompletedError(CadenceError): +class WorkflowExecutionAlreadyCompletedError(CadenceRpcError): pass -class DomainNotActiveError(CadenceError): +class DomainNotActiveError(CadenceRpcError): def __init__( self, message: str, @@ -58,7 +61,7 @@ def __init__( self.active_clusters = active_clusters -class ClientVersionNotSupportedError(CadenceError): +class ClientVersionNotSupportedError(CadenceRpcError): def __init__( self, message: str, @@ -75,29 +78,29 @@ def __init__( self.supported_versions = supported_versions -class FeatureNotEnabledError(CadenceError): +class FeatureNotEnabledError(CadenceRpcError): def __init__(self, message: str, code: grpc.StatusCode, feature_flag: str) -> None: super().__init__(message, code, feature_flag) self.feature_flag = feature_flag -class CancellationAlreadyRequestedError(CadenceError): +class CancellationAlreadyRequestedError(CadenceRpcError): pass -class DomainAlreadyExistsError(CadenceError): +class DomainAlreadyExistsError(CadenceRpcError): pass -class LimitExceededError(CadenceError): +class LimitExceededError(CadenceRpcError): pass -class QueryFailedError(CadenceError): +class QueryFailedError(CadenceRpcError): pass -class ServiceBusyError(CadenceError): +class ServiceBusyError(CadenceRpcError): def __init__(self, message: str, code: grpc.StatusCode, reason: str) -> None: super().__init__(message, code, reason) self.reason = reason diff --git a/cadence/worker/_base_task_handler.py b/cadence/worker/_base_task_handler.py index 042f80b..b7aa108 100644 --- a/cadence/worker/_base_task_handler.py +++ b/cadence/worker/_base_task_handler.py @@ -26,7 +26,7 @@ def __init__(self, client, task_list: str, identity: str, **options): **options: Additional options for the handler """ self._client = client - self._task_list = task_list + self.task_list = task_list self._identity = identity self._options = options diff --git a/cadence/worker/_decision_task_handler.py b/cadence/worker/_decision_task_handler.py index 0f5f780..4710bc1 100644 --- a/cadence/worker/_decision_task_handler.py +++ b/cadence/worker/_decision_task_handler.py @@ -108,6 +108,7 @@ async def _handle_task_implementation( workflow_domain=self._client.domain, workflow_id=workflow_id, workflow_run_id=run_id, + workflow_task_list=self.task_list, ) # Use thread-safe cache to get or create workflow engine diff --git a/cadence/worker/_registry.py b/cadence/worker/_registry.py index 05bff67..cfb2953 100644 --- a/cadence/worker/_registry.py +++ b/cadence/worker/_registry.py @@ -217,7 +217,7 @@ def of(*args: "Registry") -> "Registry": def _find_activity_definitions(instance: object) -> list[ActivityDefinition]: - attr_to_def = {} + attr_to_def: dict[str, ActivityDefinition] = {} for t in instance.__class__.__mro__: for attr in dir(t): if attr.startswith("_"): @@ -238,6 +238,7 @@ def _find_activity_definitions(instance: object) -> list[ActivityDefinition]: definition.name, definition.strategy, definition.params, + definition.result_type, ) ) diff --git a/cadence/workflow.py b/cadence/workflow.py index 0e346ea..913ebd1 100644 --- a/cadence/workflow.py +++ b/cadence/workflow.py @@ -2,20 +2,45 @@ from contextlib import contextmanager from contextvars import ContextVar from dataclasses import dataclass +from datetime import timedelta from typing import ( - Iterator, Callable, - TypeVar, - TypedDict, - Type, cast, - Any, Optional, Union, + Iterator, + TypedDict, + TypeVar, + Type, + Unpack, + Any, ) import inspect from cadence.client import Client +from cadence.data_converter import DataConverter + +ResultType = TypeVar("ResultType") + + +class ActivityOptions(TypedDict, total=False): + task_list: str + schedule_to_close_timeout: timedelta + schedule_to_start_timeout: timedelta + start_to_close_timeout: timedelta + heartbeat_timeout: timedelta + + +async def execute_activity( + activity: str, + result_type: Type[ResultType], + *args: Any, + **kwargs: Unpack[ActivityOptions], +) -> ResultType: + return await WorkflowContext.get().execute_activity( + activity, result_type, *args, **kwargs + ) + T = TypeVar("T", bound=Callable[..., Any]) @@ -143,6 +168,7 @@ class WorkflowInfo: workflow_domain: str workflow_id: str workflow_run_id: str + workflow_task_list: str class WorkflowContext(ABC): @@ -154,6 +180,18 @@ def info(self) -> WorkflowInfo: ... @abstractmethod def client(self) -> Client: ... + @abstractmethod + def data_converter(self) -> DataConverter: ... + + @abstractmethod + async def execute_activity( + self, + activity: str, + result_type: Type[ResultType], + *args: Any, + **kwargs: Unpack[ActivityOptions], + ) -> ResultType: ... + @contextmanager def _activate(self) -> Iterator[None]: token = WorkflowContext._var.set(self) @@ -166,4 +204,7 @@ def is_set() -> bool: @staticmethod def get() -> "WorkflowContext": - return WorkflowContext._var.get() + res = WorkflowContext._var.get(None) + if res is None: + raise RuntimeError("Workflow function used outside of workflow context") + return res diff --git a/tests/cadence/_internal/rpc/test_error.py b/tests/cadence/_internal/rpc/test_error.py index 2b67a7c..d3c30b9 100644 --- a/tests/cadence/_internal/rpc/test_error.py +++ b/tests/cadence/_internal/rpc/test_error.py @@ -13,7 +13,7 @@ from google.protobuf.message import Message from cadence.api.v1.service_meta_pb2 import HealthRequest, HealthResponse -from cadence.error import CadenceError +from cadence.error import CadenceRpcError class FakeService(service_meta_pb2_grpc.MetaAPIServicer): @@ -159,7 +159,7 @@ def fake_service(): code=code_pb2.PERMISSION_DENIED, message="no permission" ) ), - error.CadenceError( + error.CadenceRpcError( message="no permission", code=StatusCode.PERMISSION_DENIED ), id="unknown error type", @@ -167,7 +167,9 @@ def fake_service(): ], ) @pytest.mark.asyncio -async def test_map_error(fake_service, err: Message | Status, expected: CadenceError): +async def test_map_error( + fake_service, err: Message | Status, expected: CadenceRpcError +): async with insecure_channel( f"[::]:{fake_service.port}", interceptors=[CadenceErrorInterceptor()] ) as channel: diff --git a/tests/cadence/_internal/rpc/test_retry.py b/tests/cadence/_internal/rpc/test_retry.py index aed7f57..64249fd 100644 --- a/tests/cadence/_internal/rpc/test_retry.py +++ b/tests/cadence/_internal/rpc/test_retry.py @@ -17,7 +17,7 @@ DescribeWorkflowExecutionRequest, GetWorkflowExecutionHistoryRequest, ) -from cadence.error import CadenceError, FeatureNotEnabledError, EntityNotExistsError +from cadence.error import CadenceRpcError, FeatureNotEnabledError, EntityNotExistsError simple_policy = ExponentialRetryPolicy( initial_interval=1, backoff_coefficient=2, max_interval=10, max_attempts=6 @@ -145,7 +145,7 @@ def fake_service(): ) @pytest.mark.asyncio async def test_retryable_error( - fake_service, case: str, expected_calls: int, expected_err: Type[CadenceError] + fake_service, case: str, expected_calls: int, expected_err: Type[CadenceRpcError] ): fake_service.counter = 0 async with insecure_channel( diff --git a/tests/cadence/_internal/test_decision_state_machine.py b/tests/cadence/_internal/test_decision_state_machine.py deleted file mode 100644 index e967c2f..0000000 --- a/tests/cadence/_internal/test_decision_state_machine.py +++ /dev/null @@ -1,432 +0,0 @@ -import pytest - -from cadence.api.v1 import ( - decision_pb2 as decision, - history_pb2 as history, - common_pb2 as common, -) - -from cadence._internal.decision_state_machine import ( - ActivityDecisionMachine, - TimerDecisionMachine, - ChildWorkflowDecisionMachine, - DecisionManager, - DecisionState, -) - - -@pytest.mark.unit -def test_timer_state_machine_cancel_before_sent(): - attrs = decision.StartTimerDecisionAttributes(timer_id="t-cbs") - m = TimerDecisionMachine(timer_id="t-cbs", start_attributes=attrs) - m.request_cancel() - d = m.collect_pending_decisions() - assert len(d) == 2 - assert d[0].HasField("start_timer_decision_attributes") - assert d[1].HasField("cancel_timer_decision_attributes") - - -@pytest.mark.unit -def test_timer_state_machine_cancel_after_initiated(): - attrs = decision.StartTimerDecisionAttributes(timer_id="t-cai") - m = TimerDecisionMachine(timer_id="t-cai", start_attributes=attrs) - _ = m.collect_pending_decisions() - m.handle_event( - history.HistoryEvent( - event_id=1, - timer_started_event_attributes=history.TimerStartedEventAttributes( - timer_id="t-cai" - ), - ), - "initiated", - ) - m.request_cancel() - d = m.collect_pending_decisions() - assert len(d) == 1 and d[0].HasField("cancel_timer_decision_attributes") - - -@pytest.mark.unit -def test_timer_state_machine_completed_after_cancel(): - attrs = decision.StartTimerDecisionAttributes(timer_id="t-cac") - m = TimerDecisionMachine(timer_id="t-cac", start_attributes=attrs) - _ = m.collect_pending_decisions() - m.handle_event( - history.HistoryEvent( - event_id=2, - timer_started_event_attributes=history.TimerStartedEventAttributes( - timer_id="t-cac" - ), - ), - "initiated", - ) - m.request_cancel() - _ = m.collect_pending_decisions() - m.handle_event( - history.HistoryEvent( - event_id=3, - timer_fired_event_attributes=history.TimerFiredEventAttributes( - timer_id="t-cac", started_event_id=2 - ), - ), - "completion", - ) - assert m.status is DecisionState.COMPLETED - - -@pytest.mark.unit -def test_timer_state_machine_complete_without_cancel(): - attrs = decision.StartTimerDecisionAttributes(timer_id="t-cwc") - m = TimerDecisionMachine(timer_id="t-cwc", start_attributes=attrs) - _ = m.collect_pending_decisions() - m.handle_event( - history.HistoryEvent( - event_id=4, - timer_started_event_attributes=history.TimerStartedEventAttributes( - timer_id="t-cwc" - ), - ), - "initiated", - ) - m.handle_event( - history.HistoryEvent( - event_id=5, - timer_fired_event_attributes=history.TimerFiredEventAttributes( - timer_id="t-cwc", started_event_id=4 - ), - ), - "completion", - ) - assert m.status is DecisionState.COMPLETED - - -@pytest.mark.unit -def test_timer_cancel_event_ordering(): - attrs = decision.StartTimerDecisionAttributes(timer_id="t-ord") - m = TimerDecisionMachine(timer_id="t-ord", start_attributes=attrs) - _ = m.collect_pending_decisions() - m.handle_event( - history.HistoryEvent( - event_id=10, - timer_started_event_attributes=history.TimerStartedEventAttributes( - timer_id="t-ord" - ), - ), - "initiated", - ) - m.request_cancel() - d1 = m.collect_pending_decisions() - assert len(d1) == 1 and d1[0].HasField("cancel_timer_decision_attributes") - # Simulate cancel failed -> should retry emit - m.handle_event( - history.HistoryEvent( - event_id=11, - cancel_timer_failed_event_attributes=history.CancelTimerFailedEventAttributes( - timer_id="t-ord" - ), - ), - "cancel_failed", - ) - d2 = m.collect_pending_decisions() - assert len(d2) == 1 and d2[0].HasField("cancel_timer_decision_attributes") - - -@pytest.mark.unit -def test_activity_state_machine_complete_without_cancel(): - attrs = decision.ScheduleActivityTaskDecisionAttributes(activity_id="act-1") - m = ActivityDecisionMachine(activity_id="act-1", schedule_attributes=attrs) - d = m.collect_pending_decisions() - assert len(d) == 1 and d[0].HasField("schedule_activity_task_decision_attributes") - m.handle_event( - history.HistoryEvent( - event_id=20, - activity_task_scheduled_event_attributes=history.ActivityTaskScheduledEventAttributes( - activity_id="act-1" - ), - ), - "initiated", - ) - m.handle_event( - history.HistoryEvent( - event_id=21, - activity_task_started_event_attributes=history.ActivityTaskStartedEventAttributes( - scheduled_event_id=20 - ), - ), - "started", - ) - m.handle_event( - history.HistoryEvent( - event_id=22, - activity_task_completed_event_attributes=history.ActivityTaskCompletedEventAttributes( - scheduled_event_id=20, started_event_id=21 - ), - ), - "completion", - ) - assert m.status is DecisionState.COMPLETED - - -@pytest.mark.unit -def test_activity_state_machine_cancel_before_sent(): - attrs = decision.ScheduleActivityTaskDecisionAttributes(activity_id="act-cbs") - m = ActivityDecisionMachine(activity_id="act-cbs", schedule_attributes=attrs) - m.request_cancel() - d = m.collect_pending_decisions() - # Should emit schedule then cancel - assert len(d) == 2 - assert d[0].HasField("schedule_activity_task_decision_attributes") - assert d[1].HasField("request_cancel_activity_task_decision_attributes") - - -@pytest.mark.unit -def test_activity_state_machine_cancel_after_sent(): - attrs = decision.ScheduleActivityTaskDecisionAttributes(activity_id="act-cas") - m = ActivityDecisionMachine(activity_id="act-cas", schedule_attributes=attrs) - _ = m.collect_pending_decisions() - m.request_cancel() - d = m.collect_pending_decisions() - assert len(d) == 1 and d[0].HasField( - "request_cancel_activity_task_decision_attributes" - ) - - -@pytest.mark.unit -def test_activity_state_machine_completed_after_cancel(): - attrs = decision.ScheduleActivityTaskDecisionAttributes(activity_id="act-cac") - m = ActivityDecisionMachine(activity_id="act-cac", schedule_attributes=attrs) - _ = m.collect_pending_decisions() - m.handle_event( - history.HistoryEvent( - event_id=30, - activity_task_scheduled_event_attributes=history.ActivityTaskScheduledEventAttributes( - activity_id="act-cac" - ), - ), - "initiated", - ) - m.handle_event( - history.HistoryEvent( - event_id=31, - activity_task_started_event_attributes=history.ActivityTaskStartedEventAttributes( - scheduled_event_id=30 - ), - ), - "started", - ) - m.request_cancel() - _ = m.collect_pending_decisions() - m.handle_event( - history.HistoryEvent( - event_id=32, - activity_task_completed_event_attributes=history.ActivityTaskCompletedEventAttributes( - scheduled_event_id=30, started_event_id=31 - ), - ), - "completion", - ) - assert m.status is DecisionState.COMPLETED - - -@pytest.mark.unit -def test_child_workflow_state_machine_basic(): - attrs = decision.StartChildWorkflowExecutionDecisionAttributes( - domain="d1", workflow_id="wf-1", workflow_type=common.WorkflowType(name="t") - ) - m = ChildWorkflowDecisionMachine(client_id="cw-1", start_attributes=attrs) - d = m.collect_pending_decisions() - assert len(d) == 1 and d[0].HasField( - "start_child_workflow_execution_decision_attributes" - ) - m.handle_event( - history.HistoryEvent( - event_id=40, - start_child_workflow_execution_initiated_event_attributes=history.StartChildWorkflowExecutionInitiatedEventAttributes( - domain="d1", workflow_id="wf-1" - ), - ), - "initiated", - ) - m.handle_event( - history.HistoryEvent( - event_id=41, - child_workflow_execution_started_event_attributes=history.ChildWorkflowExecutionStartedEventAttributes( - initiated_event_id=40 - ), - ), - "started", - ) - m.handle_event( - history.HistoryEvent( - event_id=42, - child_workflow_execution_completed_event_attributes=history.ChildWorkflowExecutionCompletedEventAttributes( - initiated_event_id=40 - ), - ), - "completion", - ) - assert m.status is DecisionState.COMPLETED - - -@pytest.mark.unit -def test_child_workflow_state_machine_cancel_succeed(): - attrs = decision.StartChildWorkflowExecutionDecisionAttributes( - domain="d2", workflow_id="wf-2", workflow_type=common.WorkflowType(name="t2") - ) - m = ChildWorkflowDecisionMachine(client_id="cw-2", start_attributes=attrs) - _ = m.collect_pending_decisions() - m.handle_event( - history.HistoryEvent( - event_id=50, - start_child_workflow_execution_initiated_event_attributes=history.StartChildWorkflowExecutionInitiatedEventAttributes( - domain="d2", workflow_id="wf-2" - ), - ), - "initiated", - ) - m.request_cancel() - d = m.collect_pending_decisions() - assert len(d) == 1 and d[0].HasField( - "request_cancel_external_workflow_execution_decision_attributes" - ) - m.handle_event( - history.HistoryEvent( - event_id=51, - child_workflow_execution_canceled_event_attributes=history.ChildWorkflowExecutionCanceledEventAttributes( - initiated_event_id=50 - ), - ), - "canceled", - ) - assert m.status is DecisionState.CANCELED_AFTER_INITIATED - - -@pytest.mark.unit -@pytest.mark.skip("External cancel failure event handling is not implemented") -def test_child_workflow_state_machine_cancel_failed(): - pass - - -@pytest.mark.unit -@pytest.mark.skip("Marker decision state machine is not implemented in this module") -def test_marker_state_machine(): - pass - - -@pytest.mark.unit -@pytest.mark.skip( - "Upsert search attributes decision state machine is not implemented in this module" -) -def test_upsert_search_attributes_decision_state_machine(): - pass - - -@pytest.mark.unit -@pytest.mark.skip( - "Cancel external workflow decision state machine is not implemented in this module" -) -def test_cancel_external_workflow_state_machine_succeed(): - pass - - -@pytest.mark.unit -@pytest.mark.skip( - "Cancel external workflow decision state machine is not implemented in this module" -) -def test_cancel_external_workflow_state_machine_failed(): - pass - - -@pytest.mark.unit -def test_manager_aggregates_and_routes(): - dm = DecisionManager() - - # Create three machines via public API - a = dm.schedule_activity( - "a1", decision.ScheduleActivityTaskDecisionAttributes(activity_id="a1") - ) - t = dm.start_timer("t1", decision.StartTimerDecisionAttributes(timer_id="t1")) - c = dm.start_child_workflow( - "c1", - decision.StartChildWorkflowExecutionDecisionAttributes( - domain="d", workflow_id="w1", workflow_type=common.WorkflowType(name="t") - ), - ) - - # First collect should include 3 decisions (schedule/start/start_child) - decisions_first = dm.collect_pending_decisions() - assert len(decisions_first) == 3 - - # Idempotent on second collect - assert dm.collect_pending_decisions() == [] - - # Route initiated events - dm.handle_history_event( - history.HistoryEvent( - event_id=100, - activity_task_scheduled_event_attributes=history.ActivityTaskScheduledEventAttributes( - activity_id="a1" - ), - ) - ) - dm.handle_history_event( - history.HistoryEvent( - event_id=101, - timer_started_event_attributes=history.TimerStartedEventAttributes( - timer_id="t1" - ), - ) - ) - dm.handle_history_event( - history.HistoryEvent( - event_id=102, - start_child_workflow_execution_initiated_event_attributes=history.StartChildWorkflowExecutionInitiatedEventAttributes( - domain="d", workflow_id="w1" - ), - ) - ) - - assert a.status is DecisionState.INITIATED - assert t.status is DecisionState.INITIATED - assert c.status is DecisionState.INITIATED - - # Route started and completion events - dm.handle_history_event( - history.HistoryEvent( - event_id=103, - activity_task_started_event_attributes=history.ActivityTaskStartedEventAttributes( - scheduled_event_id=100 - ), - ) - ) - dm.handle_history_event( - history.HistoryEvent( - event_id=104, - child_workflow_execution_started_event_attributes=history.ChildWorkflowExecutionStartedEventAttributes( - initiated_event_id=102 - ), - ) - ) - dm.handle_history_event( - history.HistoryEvent( - event_id=105, - activity_task_completed_event_attributes=history.ActivityTaskCompletedEventAttributes( - scheduled_event_id=100, started_event_id=103 - ), - ) - ) - dm.handle_history_event( - history.HistoryEvent( - event_id=106, - timer_fired_event_attributes=history.TimerFiredEventAttributes( - timer_id="t1", started_event_id=101 - ), - ) - ) - dm.handle_history_event( - history.HistoryEvent( - event_id=107, - child_workflow_execution_completed_event_attributes=history.ChildWorkflowExecutionCompletedEventAttributes( - initiated_event_id=102 - ), - ) - ) diff --git a/tests/cadence/_internal/workflow/statemachine/test_activity_state_machine.py b/tests/cadence/_internal/workflow/statemachine/test_activity_state_machine.py new file mode 100644 index 0000000..4ab3b91 --- /dev/null +++ b/tests/cadence/_internal/workflow/statemachine/test_activity_state_machine.py @@ -0,0 +1,196 @@ +from asyncio import CancelledError + +import pytest + +from cadence._internal.workflow.statemachine.activity_state_machine import ( + ActivityStateMachine, +) +from cadence._internal.workflow.statemachine.decision_state_machine import ( + DecisionFuture, +) +from cadence.api.v1 import decision, history +from cadence.api.v1.common_pb2 import Payload, Failure +from cadence.api.v1.decision_pb2 import RequestCancelActivityTaskDecisionAttributes +from cadence.error import ActivityFailure + +### These tests have to be async because they rely on the presence of an eventloop + + +async def test_activity_state_machine_initiated(): + attrs = decision.ScheduleActivityTaskDecisionAttributes(activity_id="a") + completed = DecisionFuture[Payload]() + m = ActivityStateMachine(attrs, completed) + + assert completed.done() is False + assert m.get_decision() == decision.Decision( + schedule_activity_task_decision_attributes=attrs + ) + + +async def test_activity_state_machine_cancelled_before_initiated(): + attrs = decision.ScheduleActivityTaskDecisionAttributes(activity_id="a") + completed = DecisionFuture[Payload]() + + m = ActivityStateMachine(attrs, completed) + res = m.request_cancel() + + assert res is True + assert completed.done() is True + assert completed.cancelled() is True + assert m.get_decision() is None + + +async def test_activity_state_machine_cancelled_after_initiated(): + attrs = decision.ScheduleActivityTaskDecisionAttributes(activity_id="a") + completed = DecisionFuture[Payload]() + m = ActivityStateMachine(attrs, completed) + + m.handle_scheduled(history.ActivityTaskScheduledEventAttributes(activity_id="a")) + res = m.request_cancel() + + assert res is True + assert completed.done() is False + assert m.get_decision() == decision.Decision( + request_cancel_activity_task_decision_attributes=RequestCancelActivityTaskDecisionAttributes( + activity_id="a" + ) + ) + + +async def test_activity_state_machine_completed(): + attrs = decision.ScheduleActivityTaskDecisionAttributes(activity_id="a") + completed = DecisionFuture[Payload]() + m = ActivityStateMachine(attrs, completed) + + m.handle_scheduled(history.ActivityTaskScheduledEventAttributes(activity_id="a")) + m.handle_started(history.ActivityTaskStartedEventAttributes()) + m.handle_completed( + history.ActivityTaskCompletedEventAttributes(result=Payload(data=b"result")) + ) + + assert completed.done() is True + assert m.get_decision() is None + assert completed.result() == Payload(data=b"result") + + +async def test_activity_state_machine_timeout(): + attrs = decision.ScheduleActivityTaskDecisionAttributes(activity_id="a") + completed = DecisionFuture[Payload]() + m = ActivityStateMachine(attrs, completed) + + m.handle_scheduled(history.ActivityTaskScheduledEventAttributes(activity_id="a")) + m.handle_timeout( + history.ActivityTaskTimedOutEventAttributes( + details=Payload(data="error message".encode()) + ) + ) + + assert completed.done() is True + assert m.get_decision() is None + with pytest.raises(ActivityFailure, match="error message"): + completed.result() + + +async def test_activity_state_machine_failed(): + attrs = decision.ScheduleActivityTaskDecisionAttributes(activity_id="a") + completed = DecisionFuture[Payload]() + m = ActivityStateMachine(attrs, completed) + + m.handle_scheduled(history.ActivityTaskScheduledEventAttributes(activity_id="a")) + m.handle_started(history.ActivityTaskStartedEventAttributes()) + m.handle_failed( + history.ActivityTaskFailedEventAttributes( + failure=Failure(reason="error message") + ) + ) + + assert completed.done() is True + assert m.get_decision() is None + with pytest.raises(ActivityFailure, match="error message"): + completed.result() + + +async def test_activity_state_machine_cancel_confirmed(): + attrs = decision.ScheduleActivityTaskDecisionAttributes(activity_id="a") + completed = DecisionFuture[Payload]() + m = ActivityStateMachine(attrs, completed) + + m.handle_scheduled(history.ActivityTaskScheduledEventAttributes(activity_id="a")) + m.request_cancel() + m.handle_cancel_requested(history.ActivityTaskCancelRequestedEventAttributes()) + + assert m.get_decision() is None + assert completed.done() is False + assert m.get_decision() is None + + +async def test_activity_state_machine_complete_after_cancel(): + attrs = decision.ScheduleActivityTaskDecisionAttributes(activity_id="a") + completed = DecisionFuture[Payload]() + m = ActivityStateMachine(attrs, completed) + + m.handle_scheduled(history.ActivityTaskScheduledEventAttributes(activity_id="a")) + m.request_cancel() + m.handle_cancel_requested(history.ActivityTaskCancelRequestedEventAttributes()) + m.handle_completed( + history.ActivityTaskCompletedEventAttributes(result=Payload(data=b"result")) + ) + + assert m.get_decision() is None + assert completed.done() is True + assert completed.result() == Payload(data=b"result") + + +async def test_activity_state_machine_cancel_accepted(): + attrs = decision.ScheduleActivityTaskDecisionAttributes(activity_id="a") + completed = DecisionFuture[Payload]() + m = ActivityStateMachine(attrs, completed) + + m.handle_scheduled(history.ActivityTaskScheduledEventAttributes(activity_id="a")) + m.request_cancel() + m.handle_cancel_requested(history.ActivityTaskCancelRequestedEventAttributes()) + m.handle_canceled( + history.ActivityTaskCanceledEventAttributes( + details=Payload(data="error message".encode()) + ) + ) + + assert m.get_decision() is None + assert completed.done() is True + assert completed.cancelled() is True + with pytest.raises(CancelledError, match="error message"): + completed.result() + + +async def test_activity_state_machine_cancel_failed(): + attrs = decision.ScheduleActivityTaskDecisionAttributes(activity_id="a") + completed = DecisionFuture[Payload]() + m = ActivityStateMachine(attrs, completed) + + m.handle_scheduled(history.ActivityTaskScheduledEventAttributes(activity_id="a")) + m.request_cancel() + m.handle_cancel_requested(history.ActivityTaskCancelRequestedEventAttributes()) + m.handle_cancel_failed(history.RequestCancelActivityTaskFailedEventAttributes()) + + assert m.get_decision() is None + assert completed.done() is False + assert completed.cancelled() is False + + +async def test_activity_state_machine_completed_after_cancel_failed(): + attrs = decision.ScheduleActivityTaskDecisionAttributes(activity_id="a") + completed = DecisionFuture[Payload]() + m = ActivityStateMachine(attrs, completed) + + m.handle_scheduled(history.ActivityTaskScheduledEventAttributes(activity_id="a")) + m.request_cancel() + m.handle_cancel_requested(history.ActivityTaskCancelRequestedEventAttributes()) + m.handle_cancel_failed(history.RequestCancelActivityTaskFailedEventAttributes()) + m.handle_completed( + history.ActivityTaskCompletedEventAttributes(result=Payload(data=b"result")) + ) + + assert m.get_decision() is None + assert completed.result() == Payload(data=b"result") + assert completed.done() is True + assert completed.cancelled() is False diff --git a/tests/cadence/_internal/workflow/statemachine/test_decision_manager.py b/tests/cadence/_internal/workflow/statemachine/test_decision_manager.py new file mode 100644 index 0000000..d0d862b --- /dev/null +++ b/tests/cadence/_internal/workflow/statemachine/test_decision_manager.py @@ -0,0 +1,191 @@ +from asyncio import CancelledError + +import pytest + +from cadence._internal.workflow.statemachine.decision_manager import DecisionManager +from cadence.api.v1 import history, decision +from cadence.api.v1.common_pb2 import Payload + + +async def test_activity_dispatch(): + decisions = DecisionManager() + + activity_result = decisions.schedule_activity( + decision.ScheduleActivityTaskDecisionAttributes(activity_id="a") + ) + decisions.handle_history_event(activity_scheduled(1, "a")) + decisions.handle_history_event(activity_started(2, 1)) + decisions.handle_history_event(activity_completed(3, 1, Payload(data=b"completed"))) + + assert activity_result.done() is True + assert activity_result.result() == Payload(data=b"completed") + + +async def test_simple_cancellation(): + decisions = DecisionManager() + + activity_result = decisions.schedule_activity( + decision.ScheduleActivityTaskDecisionAttributes(activity_id="a") + ) + activity_result.cancel() + + assert activity_result.done() is True + assert activity_result.cancelled() is True + + +async def test_cancellation_not_immediate(): + decisions = DecisionManager() + + activity_result = decisions.schedule_activity( + decision.ScheduleActivityTaskDecisionAttributes(activity_id="a") + ) + decisions.handle_history_event(activity_scheduled(1, "a")) + activity_result.cancel() + + assert activity_result.done() is False + assert activity_result.cancelled() is False + + +async def test_cancellation_completed(): + decisions = DecisionManager() + + activity_result = decisions.schedule_activity( + decision.ScheduleActivityTaskDecisionAttributes(activity_id="a") + ) + decisions.handle_history_event(activity_scheduled(1, "a")) + activity_result.cancel() + decisions.handle_history_event( + history.HistoryEvent( + event_id=2, + activity_task_cancel_requested_event_attributes=history.ActivityTaskCancelRequestedEventAttributes( + activity_id="a" + ), + ) + ) + decisions.handle_history_event( + history.HistoryEvent( + event_id=3, + activity_task_canceled_event_attributes=history.ActivityTaskCanceledEventAttributes( + scheduled_event_id=1, details=Payload(data=b"oh no") + ), + ) + ) + + assert activity_result.done() is True + assert activity_result.cancelled() is True + with pytest.raises(CancelledError, match="oh no"): + activity_result.result() + + +async def test_collect_decisions(): + decisions = DecisionManager() + + activity1 = decisions.schedule_activity( + decision.ScheduleActivityTaskDecisionAttributes(activity_id="a") + ) + activity2 = decisions.schedule_activity( + decision.ScheduleActivityTaskDecisionAttributes(activity_id="b") + ) + + # Order matters + assert decisions.collect_pending_decisions() == [ + decision.Decision( + schedule_activity_task_decision_attributes=decision.ScheduleActivityTaskDecisionAttributes( + activity_id="a" + ) + ), + decision.Decision( + schedule_activity_task_decision_attributes=decision.ScheduleActivityTaskDecisionAttributes( + activity_id="b" + ) + ), + ] + assert activity1.done() is False + assert activity2.done() is False + + +async def test_collect_decisions_ignore_empty(): + decisions = DecisionManager() + + _ = decisions.schedule_activity( + decision.ScheduleActivityTaskDecisionAttributes(activity_id="a") + ) + decisions.handle_history_event(activity_scheduled(1, "a")) + + assert decisions.collect_pending_decisions() == [] + + +async def test_collection_decisions_reordering(): + # Decisions should be emitted in the order that they happened within the workflow + decisions = DecisionManager() + + activity1 = decisions.schedule_activity( + decision.ScheduleActivityTaskDecisionAttributes(activity_id="a") + ) + activity2 = decisions.schedule_activity( + decision.ScheduleActivityTaskDecisionAttributes(activity_id="b") + ) + + assert decisions.collect_pending_decisions() == [ + decision.Decision( + schedule_activity_task_decision_attributes=decision.ScheduleActivityTaskDecisionAttributes( + activity_id="a" + ) + ), + decision.Decision( + schedule_activity_task_decision_attributes=decision.ScheduleActivityTaskDecisionAttributes( + activity_id="b" + ) + ), + ] + + decisions.handle_history_event(activity_scheduled(1, "a")) + decisions.handle_history_event(activity_scheduled(2, "b")) + # cancel them in reverse order + activity2.cancel() + activity1.cancel() + + # Order matters + assert decisions.collect_pending_decisions() == [ + decision.Decision( + request_cancel_activity_task_decision_attributes=decision.RequestCancelActivityTaskDecisionAttributes( + activity_id="b" + ) + ), + decision.Decision( + request_cancel_activity_task_decision_attributes=decision.RequestCancelActivityTaskDecisionAttributes( + activity_id="a" + ) + ), + ] + assert activity1.done() is False + assert activity2.done() is False + + +def activity_scheduled(event_id: int, activity_id: str) -> history.HistoryEvent: + return history.HistoryEvent( + event_id=event_id, + activity_task_scheduled_event_attributes=history.ActivityTaskScheduledEventAttributes( + activity_id=activity_id + ), + ) + + +def activity_started(event_id: int, scheduled_id: int) -> history.HistoryEvent: + return history.HistoryEvent( + event_id=event_id, + activity_task_started_event_attributes=history.ActivityTaskStartedEventAttributes( + scheduled_event_id=scheduled_id + ), + ) + + +def activity_completed( + event_id: int, scheduled_id: int, result: Payload +) -> history.HistoryEvent: + return history.HistoryEvent( + event_id=event_id, + activity_task_completed_event_attributes=history.ActivityTaskCompletedEventAttributes( + scheduled_event_id=scheduled_id, result=result + ), + ) diff --git a/tests/cadence/_internal/workflow/statemachine/test_timer_state_machine.py b/tests/cadence/_internal/workflow/statemachine/test_timer_state_machine.py new file mode 100644 index 0000000..0ae01a3 --- /dev/null +++ b/tests/cadence/_internal/workflow/statemachine/test_timer_state_machine.py @@ -0,0 +1,72 @@ +from cadence._internal.workflow.statemachine.decision_state_machine import ( + DecisionFuture, +) +from cadence._internal.workflow.statemachine.timer_state_machine import ( + TimerStateMachine, +) +from cadence.api.v1 import decision, history +from cadence.api.v1.decision_pb2 import CancelTimerDecisionAttributes + + +### These tests have to be async because they rely on the presence of an eventloop + + +async def test_timer_state_machine_started(): + attrs = decision.StartTimerDecisionAttributes(timer_id="t-cbs") + completed = DecisionFuture[None]() + m = TimerStateMachine(attrs, completed) + + assert completed.done() is False + assert m.get_decision() == decision.Decision(start_timer_decision_attributes=attrs) + + +async def test_timer_state_machine_cancel_before_sent(): + attrs = decision.StartTimerDecisionAttributes(timer_id="t") + completed = DecisionFuture[None]() + m = TimerStateMachine(attrs, completed) + + assert m.request_cancel() is True + + assert completed.done() is True + assert m.get_decision() is None + + +async def test_timer_state_machine_cancel_after_initiated(): + attrs = decision.StartTimerDecisionAttributes(timer_id="t") + completed = DecisionFuture[None]() + m = TimerStateMachine(attrs, completed) + + m.handle_started(history.TimerStartedEventAttributes(timer_id="t")) + res = m.request_cancel() + + assert res is True + assert completed.done() is True + assert m.get_decision() == decision.Decision( + cancel_timer_decision_attributes=CancelTimerDecisionAttributes(timer_id="t") + ) + + +async def test_timer_state_machine_fired(): + attrs = decision.StartTimerDecisionAttributes(timer_id="t") + completed = DecisionFuture[None]() + m = TimerStateMachine(attrs, completed) + + m.handle_started(history.TimerStartedEventAttributes(timer_id="t")) + m.handle_fired(history.TimerFiredEventAttributes()) + + assert completed.done() is True + assert m.get_decision() is None + + +async def test_timer_state_machine_cancel_after_fired(): + attrs = decision.StartTimerDecisionAttributes(timer_id="t") + completed = DecisionFuture[None]() + m = TimerStateMachine(attrs, completed) + + m.handle_started(history.TimerStartedEventAttributes(timer_id="t")) + m.handle_fired(history.TimerFiredEventAttributes()) + res = m.request_cancel() + + assert res is False + assert completed.done() is True + assert m.get_decision() is None diff --git a/tests/cadence/_internal/workflow/test_workflow_engine_integration.py b/tests/cadence/_internal/workflow/test_workflow_engine_integration.py index 768768e..3aa5d44 100644 --- a/tests/cadence/_internal/workflow/test_workflow_engine_integration.py +++ b/tests/cadence/_internal/workflow/test_workflow_engine_integration.py @@ -27,7 +27,7 @@ def mock_client(self): client = Mock(spec=Client) client.domain = "test-domain" client.data_converter = Mock() - client.data_converter.from_data = AsyncMock(return_value=["test-input"]) + client.data_converter.from_data = Mock(return_value=["test-input"]) return client @pytest.fixture @@ -38,6 +38,7 @@ def workflow_info(self): workflow_domain="test-domain", workflow_id="test-workflow-id", workflow_run_id="test-run-id", + workflow_task_list="test-task-list", ) @pytest.fixture @@ -184,7 +185,7 @@ async def test_extract_workflow_input_success(self, workflow_engine, mock_client decision_task = self.create_mock_decision_task() # Extract workflow input - input_data = await workflow_engine._extract_workflow_input(decision_task) + input_data = workflow_engine._extract_workflow_input(decision_task) # Verify the input was extracted assert input_data == "test-input" @@ -200,7 +201,7 @@ async def test_extract_workflow_input_no_history( # No history set # Extract workflow input - input_data = await workflow_engine._extract_workflow_input(decision_task) + input_data = workflow_engine._extract_workflow_input(decision_task) # Verify no input was extracted assert input_data is None @@ -230,7 +231,7 @@ async def test_extract_workflow_input_no_started_event( decision_task.history.CopyFrom(history) # Extract workflow input - input_data = await workflow_engine._extract_workflow_input(decision_task) + input_data = workflow_engine._extract_workflow_input(decision_task) # Verify no input was extracted assert input_data is None @@ -248,7 +249,7 @@ async def test_extract_workflow_input_deserialization_error( ) # Extract workflow input - input_data = await workflow_engine._extract_workflow_input(decision_task) + input_data = workflow_engine._extract_workflow_input(decision_task) # Verify no input was extracted due to error assert input_data is None diff --git a/tests/cadence/common_activities.py b/tests/cadence/common_activities.py index a183e42..b6904b4 100644 --- a/tests/cadence/common_activities.py +++ b/tests/cadence/common_activities.py @@ -28,6 +28,11 @@ async def async_fn() -> None: pass +@activity.defn() +async def async_echo(incoming: str) -> str: + return incoming + + class Activities: @activity.defn() def echo_sync(self, incoming: str) -> str: diff --git a/tests/cadence/data_converter_test.py b/tests/cadence/data_converter_test.py index 91a2ff4..b6a1a7b 100644 --- a/tests/cadence/data_converter_test.py +++ b/tests/cadence/data_converter_test.py @@ -72,11 +72,11 @@ class _TestDataClass: ), ], ) -async def test_data_converter_from_data( +def test_data_converter_from_data( json: str, types: list[Type], expected: list[Any] ) -> None: converter = DefaultDataConverter() - actual = await converter.from_data(Payload(data=json.encode()), types) + actual = converter.from_data(Payload(data=json.encode()), types) assert expected == actual @@ -97,8 +97,8 @@ async def test_data_converter_from_data( ), ], ) -async def test_data_converter_to_data(values: list[Any], expected: str) -> None: +def test_data_converter_to_data(values: list[Any], expected: str) -> None: converter = DefaultDataConverter() converter._encoder = json.Encoder(order="deterministic") - actual = await converter.to_data(values) + actual = converter.to_data(values) assert actual.data.decode() == expected diff --git a/tests/cadence/test_activity.py b/tests/cadence/test_activity.py new file mode 100644 index 0000000..1fcfd5a --- /dev/null +++ b/tests/cadence/test_activity.py @@ -0,0 +1,51 @@ +from cadence.activity import ActivityParameter, ExecutionStrategy +from tests.cadence.common_activities import ( + simple_fn, + async_echo, + ActivityInterface, + Activities, +) + + +def test_sync(): + definition = simple_fn + + assert definition.name == "simple_fn" + assert definition.params == [] + assert definition.result_type is None.__class__ + assert definition.strategy == ExecutionStrategy.THREAD_POOL + + +def test_async(): + definition = async_echo + + assert definition.name == "async_echo" + assert definition.params == [ + ActivityParameter( + name="incoming", type_hint=str, has_default=False, default_value=None + ) + ] + assert definition.result_type is str + assert definition.strategy == ExecutionStrategy.ASYNC + + +def test_interface(): + definition = ActivityInterface.do_something + + assert definition.name == "ActivityInterface.do_something" + assert definition.params == [] + assert definition.result_type is str + assert definition.strategy == ExecutionStrategy.THREAD_POOL + + +def test_class_async(): + definition = Activities.echo_async + + assert definition.name == "Activities.echo_async" + assert definition.params == [ + ActivityParameter( + name="incoming", type_hint=str, has_default=False, default_value=None + ) + ] + assert definition.result_type is str + assert definition.strategy == ExecutionStrategy.ASYNC diff --git a/tests/cadence/test_client_workflow.py b/tests/cadence/test_client_workflow.py index 9f6a7c3..cdf7a2c 100644 --- a/tests/cadence/test_client_workflow.py +++ b/tests/cadence/test_client_workflow.py @@ -77,7 +77,7 @@ async def test_build_request_with_string_workflow(self, mock_client): task_start_to_close_timeout=timedelta(seconds=10), ) - request = await client._build_start_workflow_request( + request = client._build_start_workflow_request( "TestWorkflow", ("arg1", "arg2"), options ) @@ -110,7 +110,7 @@ def test_workflow(): task_start_to_close_timeout=timedelta(seconds=30), ) - request = await client._build_start_workflow_request(test_workflow, (), options) + request = client._build_start_workflow_request(test_workflow, (), options) assert request.workflow_type.name == "test_workflow" @@ -125,9 +125,7 @@ async def test_build_request_generates_workflow_id(self, mock_client): task_start_to_close_timeout=timedelta(seconds=30), ) - request = await client._build_start_workflow_request( - "TestWorkflow", (), options - ) + request = client._build_start_workflow_request("TestWorkflow", (), options) assert request.workflow_id != "" # Verify it's a valid UUID @@ -192,7 +190,7 @@ async def test_build_request_with_input_args(self, mock_client): task_start_to_close_timeout=timedelta(seconds=30), ) - request = await client._build_start_workflow_request( + request = client._build_start_workflow_request( "TestWorkflow", ("arg1", 42, {"key": "value"}), options ) @@ -211,9 +209,7 @@ async def test_build_request_with_timeouts(self, mock_client): task_start_to_close_timeout=timedelta(seconds=10), ) - request = await client._build_start_workflow_request( - "TestWorkflow", (), options - ) + request = client._build_start_workflow_request("TestWorkflow", (), options) assert request.HasField("execution_start_to_close_timeout") assert request.HasField("task_start_to_close_timeout") @@ -234,9 +230,7 @@ async def test_build_request_with_cron_schedule(self, mock_client): cron_schedule="0 * * * *", ) - request = await client._build_start_workflow_request( - "TestWorkflow", (), options - ) + request = client._build_start_workflow_request("TestWorkflow", (), options) assert request.cron_schedule == "0 * * * *" @@ -260,13 +254,13 @@ async def test_start_workflow_success(self, mock_client): client._workflow_stub = mock_client.workflow_stub # Mock the internal method to avoid full request building - async def mock_build_request(workflow, args, options): + def mock_build_request(workflow, args, options): request = StartWorkflowExecutionRequest() request.workflow_id = "test-workflow-id" request.domain = "test-domain" return request - client._build_start_workflow_request = AsyncMock(side_effect=mock_build_request) + client._build_start_workflow_request = Mock(side_effect=mock_build_request) execution = await client.start_workflow( "TestWorkflow", @@ -297,7 +291,7 @@ async def test_start_workflow_grpc_error(self, mock_client): client._workflow_stub = mock_client.workflow_stub # Mock the internal method - client._build_start_workflow_request = AsyncMock( + client._build_start_workflow_request = Mock( return_value=StartWorkflowExecutionRequest() ) @@ -325,14 +319,14 @@ async def test_start_workflow_with_kwargs(self, mock_client): # Mock the internal method to capture options captured_options = None - async def mock_build_request(workflow, args, options): + def mock_build_request(workflow, args, options): nonlocal captured_options captured_options = options request = StartWorkflowExecutionRequest() request.workflow_id = "test-workflow-id" return request - client._build_start_workflow_request = AsyncMock(side_effect=mock_build_request) + client._build_start_workflow_request = Mock(side_effect=mock_build_request) await client.start_workflow( "TestWorkflow", @@ -366,14 +360,14 @@ async def test_start_workflow_with_default_task_timeout(self, mock_client): # Mock the internal method to capture options captured_options = None - async def mock_build_request(workflow, args, options): + def mock_build_request(workflow, args, options): nonlocal captured_options captured_options = options request = StartWorkflowExecutionRequest() request.workflow_id = "test-workflow-id" return request - client._build_start_workflow_request = AsyncMock(side_effect=mock_build_request) + client._build_start_workflow_request = Mock(side_effect=mock_build_request) await client.start_workflow( "TestWorkflow", diff --git a/tests/cadence/worker/test_base_task_handler.py b/tests/cadence/worker/test_base_task_handler.py index 6d1077c..8196240 100644 --- a/tests/cadence/worker/test_base_task_handler.py +++ b/tests/cadence/worker/test_base_task_handler.py @@ -48,7 +48,7 @@ def test_initialization(self): ) assert handler._client == client - assert handler._task_list == "test_task_list" + assert handler.task_list == "test_task_list" assert handler._identity == "test_identity" assert handler._options == {"option1": "value1", "option2": "value2"} diff --git a/tests/cadence/worker/test_decision_task_handler.py b/tests/cadence/worker/test_decision_task_handler.py index b7c0957..bf48cda 100644 --- a/tests/cadence/worker/test_decision_task_handler.py +++ b/tests/cadence/worker/test_decision_task_handler.py @@ -76,7 +76,7 @@ def test_initialization(self, mock_client, mock_registry): ) assert handler._client == mock_client - assert handler._task_list == "test_task_list" + assert handler.task_list == "test_task_list" assert handler._identity == "test_identity" assert handler._registry == mock_registry assert handler._options == {"option1": "value1"} @@ -441,6 +441,7 @@ async def run(self): "workflow_domain": "test_domain", "workflow_id": "test_workflow_id", "workflow_run_id": "test_run_id", + "workflow_task_list": "test_task_list", } # Verify WorkflowEngine was created with correct parameters diff --git a/tests/integration_tests/test_client.py b/tests/integration_tests/test_client.py index 3f96b9b..5b4e785 100644 --- a/tests/integration_tests/test_client.py +++ b/tests/integration_tests/test_client.py @@ -1,14 +1,14 @@ -from datetime import timedelta import pytest +from datetime import timedelta + from cadence.api.v1.service_domain_pb2 import ( DescribeDomainRequest, DescribeDomainResponse, ) -from cadence.api.v1.service_workflow_pb2 import DescribeWorkflowExecutionRequest -from cadence.api.v1.common_pb2 import WorkflowExecution - from cadence.error import EntityNotExistsError from tests.integration_tests.helper import CadenceHelper, DOMAIN_NAME +from cadence.api.v1.service_workflow_pb2 import DescribeWorkflowExecutionRequest +from cadence.api.v1.common_pb2 import WorkflowExecution @pytest.mark.usefixtures("helper") @@ -49,7 +49,6 @@ async def test_worker_stub_accessible(helper: CadenceHelper): @pytest.mark.usefixtures("helper") async def test_workflow_stub_start_and_describe(helper: CadenceHelper): """Comprehensive test for workflow start and describe operations. - This integration test verifies: 1. Starting a workflow execution via workflow_stub 2. Describing the workflow execution