From f7c7f216155c793f8bda562909cebfa32acaa6fe Mon Sep 17 00:00:00 2001 From: Shijie Sheng Date: Wed, 5 Nov 2025 13:40:53 -0800 Subject: [PATCH 1/2] fix(workflow worker): move out event iterator out of WorkflowEngine Signed-off-by: Shijie Sheng --- cadence/_internal/workflow/context.py | 10 +- .../workflow/decision_events_iterator.py | 149 ++++++++---------- cadence/_internal/workflow/workflow_engine.py | 15 +- cadence/worker/_decision_task_handler.py | 10 +- cadence/workflow.py | 8 +- .../workflow/test_decision_events_iterator.py | 77 ++------- .../test_workflow_engine_integration.py | 7 +- .../worker/test_decision_task_handler.py | 13 +- .../worker/test_task_handler_integration.py | 10 +- 9 files changed, 118 insertions(+), 181 deletions(-) diff --git a/cadence/_internal/workflow/context.py b/cadence/_internal/workflow/context.py index d008ce4..d3c926f 100644 --- a/cadence/_internal/workflow/context.py +++ b/cadence/_internal/workflow/context.py @@ -7,7 +7,6 @@ from cadence.api.v1.common_pb2 import ActivityType from cadence.api.v1.decision_pb2 import ScheduleActivityTaskDecisionAttributes from cadence.api.v1.tasklist_pb2 import TaskList, TaskListKind -from cadence.client import Client from cadence.data_converter import DataConverter from cadence.workflow import WorkflowContext, WorkflowInfo, ResultType, ActivityOptions @@ -15,12 +14,10 @@ class Context(WorkflowContext): def __init__( self, - client: Client, info: WorkflowInfo, decision_helper: DecisionsHelper, decision_manager: DecisionManager, ): - self._client = client self._info = info self._replay_mode = True self._replay_current_time_milliseconds: Optional[int] = None @@ -30,11 +27,8 @@ def __init__( def info(self) -> WorkflowInfo: return self._info - def client(self) -> Client: - return self._client - def data_converter(self) -> DataConverter: - return self._client.data_converter + return self.info().data_converter async def execute_activity( self, @@ -80,7 +74,7 @@ async def execute_activity( schedule_attributes = ScheduleActivityTaskDecisionAttributes( activity_id=activity_id, activity_type=ActivityType(name=activity), - domain=self._client.domain, + domain=self.info().workflow_domain, task_list=TaskList(kind=TaskListKind.TASK_LIST_KIND_NORMAL, name=task_list), input=activity_input, schedule_to_close_timeout=_round_to_nearest_second(schedule_to_close), diff --git a/cadence/_internal/workflow/decision_events_iterator.py b/cadence/_internal/workflow/decision_events_iterator.py index b92b8ca..2657aa4 100644 --- a/cadence/_internal/workflow/decision_events_iterator.py +++ b/cadence/_internal/workflow/decision_events_iterator.py @@ -11,8 +11,6 @@ from cadence.api.v1.history_pb2 import HistoryEvent from cadence.api.v1.service_worker_pb2 import PollForDecisionTaskResponse -from cadence.client import Client -from cadence._internal.workflow.history_event_iterator import iterate_history_events @dataclass @@ -55,99 +53,36 @@ class DecisionEventsIterator: into decision iterations for proper workflow replay and execution. 
""" - def __init__(self, decision_task: PollForDecisionTaskResponse, client: Client): - self._client = client + def __init__( + self, decision_task: PollForDecisionTaskResponse, events: List[HistoryEvent] + ): self._decision_task = decision_task - self._events: List[HistoryEvent] = [] - self._event_index = 0 + self._events: List[HistoryEvent] = events self._decision_task_started_event: Optional[HistoryEvent] = None self._next_decision_event_id = 1 self._replay = True self._replay_current_time_milliseconds: Optional[int] = None - self._initialized = False - - @staticmethod - def _is_decision_task_started(event: HistoryEvent) -> bool: - """Check if event is DecisionTaskStarted.""" - return hasattr( - event, "decision_task_started_event_attributes" - ) and event.HasField("decision_task_started_event_attributes") - - @staticmethod - def _is_decision_task_completed(event: HistoryEvent) -> bool: - """Check if event is DecisionTaskCompleted.""" - return hasattr( - event, "decision_task_completed_event_attributes" - ) and event.HasField("decision_task_completed_event_attributes") - - @staticmethod - def _is_decision_task_failed(event: HistoryEvent) -> bool: - """Check if event is DecisionTaskFailed.""" - return hasattr( - event, "decision_task_failed_event_attributes" - ) and event.HasField("decision_task_failed_event_attributes") - - @staticmethod - def _is_decision_task_timed_out(event: HistoryEvent) -> bool: - """Check if event is DecisionTaskTimedOut.""" - return hasattr( - event, "decision_task_timed_out_event_attributes" - ) and event.HasField("decision_task_timed_out_event_attributes") - - @staticmethod - def _is_marker_recorded(event: HistoryEvent) -> bool: - """Check if event is MarkerRecorded.""" - return hasattr(event, "marker_recorded_event_attributes") and event.HasField( - "marker_recorded_event_attributes" - ) - @staticmethod - def _is_decision_task_completion(event: HistoryEvent) -> bool: - """Check if event is any kind of decision task completion.""" - return ( - DecisionEventsIterator._is_decision_task_completed(event) - or DecisionEventsIterator._is_decision_task_failed(event) - or DecisionEventsIterator._is_decision_task_timed_out(event) - ) - - async def _ensure_initialized(self): - """Initialize events list using the existing iterate_history_events.""" - if not self._initialized: - # Use existing iterate_history_events function - events_iterator = iterate_history_events(self._decision_task, self._client) - self._events = [event async for event in events_iterator] - self._initialized = True - - # Find first decision task started event - for i, event in enumerate(self._events): - if self._is_decision_task_started(event): - self._event_index = i - break + self._event_index = 0 + # Find first decision task started event + for i, event in enumerate(self._events): + if _is_decision_task_started(event): + self._event_index = i + break async def has_next_decision_events(self) -> bool: - """Check if there are more decision events to process.""" - await self._ensure_initialized() - # Look for the next DecisionTaskStarted event from current position for i in range(self._event_index, len(self._events)): - if self._is_decision_task_started(self._events[i]): + if _is_decision_task_started(self._events[i]): return True return False async def next_decision_events(self) -> DecisionEvents: - """ - Get the next set of decision events. - - This method processes events starting from a DecisionTaskStarted event - until the corresponding DecisionTaskCompleted/Failed/TimedOut event. 
- """ - await self._ensure_initialized() - # Find next DecisionTaskStarted event start_index = None for i in range(self._event_index, len(self._events)): - if self._is_decision_task_started(self._events[i]): + if _is_decision_task_started(self._events[i]): start_index = i break @@ -182,9 +117,9 @@ async def next_decision_events(self) -> DecisionEvents: decision_events.events.append(event) # Categorize the event - if self._is_marker_recorded(event): + if _is_marker_recorded(event): decision_events.markers.append(event) - elif self._is_decision_task_completion(event): + elif _is_decision_task_completion(event): # This marks the end of this decision iteration self._process_decision_completion_event(event, decision_events) current_index += 1 # Move past this event @@ -206,7 +141,7 @@ async def next_decision_events(self) -> DecisionEvents: # Check directly without calling has_next_decision_events to avoid recursion has_more = False for i in range(self._event_index, len(self._events)): - if self._is_decision_task_started(self._events[i]): + if _is_decision_task_started(self._events[i]): has_more = True break @@ -261,16 +196,16 @@ async def __anext__(self) -> DecisionEvents: def is_decision_event(event: HistoryEvent) -> bool: """Check if an event is a decision-related event.""" return ( - DecisionEventsIterator._is_decision_task_started(event) - or DecisionEventsIterator._is_decision_task_completed(event) - or DecisionEventsIterator._is_decision_task_failed(event) - or DecisionEventsIterator._is_decision_task_timed_out(event) + _is_decision_task_started(event) + or _is_decision_task_completed(event) + or _is_decision_task_failed(event) + or _is_decision_task_timed_out(event) ) def is_marker_event(event: HistoryEvent) -> bool: """Check if an event is a marker event.""" - return DecisionEventsIterator._is_marker_recorded(event) + return _is_marker_recorded(event) def extract_event_timestamp_millis(event: HistoryEvent) -> Optional[int]: @@ -279,3 +214,47 @@ def extract_event_timestamp_millis(event: HistoryEvent) -> Optional[int]: seconds = getattr(event.event_time, "seconds", 0) return seconds * 1000 if seconds > 0 else None return None + + +def _is_decision_task_started(event: HistoryEvent) -> bool: + """Check if event is DecisionTaskStarted.""" + return hasattr(event, "decision_task_started_event_attributes") and event.HasField( + "decision_task_started_event_attributes" + ) + + +def _is_decision_task_completed(event: HistoryEvent) -> bool: + """Check if event is DecisionTaskCompleted.""" + return hasattr( + event, "decision_task_completed_event_attributes" + ) and event.HasField("decision_task_completed_event_attributes") + + +def _is_decision_task_failed(event: HistoryEvent) -> bool: + """Check if event is DecisionTaskFailed.""" + return hasattr(event, "decision_task_failed_event_attributes") and event.HasField( + "decision_task_failed_event_attributes" + ) + + +def _is_decision_task_timed_out(event: HistoryEvent) -> bool: + """Check if event is DecisionTaskTimedOut.""" + return hasattr( + event, "decision_task_timed_out_event_attributes" + ) and event.HasField("decision_task_timed_out_event_attributes") + + +def _is_marker_recorded(event: HistoryEvent) -> bool: + """Check if event is MarkerRecorded.""" + return hasattr(event, "marker_recorded_event_attributes") and event.HasField( + "marker_recorded_event_attributes" + ) + + +def _is_decision_task_completion(event: HistoryEvent) -> bool: + """Check if event is any kind of decision task completion.""" + return ( + 
_is_decision_task_completed(event) + or _is_decision_task_failed(event) + or _is_decision_task_timed_out(event) + ) diff --git a/cadence/_internal/workflow/workflow_engine.py b/cadence/_internal/workflow/workflow_engine.py index 997a287..8024a75 100644 --- a/cadence/_internal/workflow/workflow_engine.py +++ b/cadence/_internal/workflow/workflow_engine.py @@ -8,7 +8,6 @@ from cadence._internal.workflow.decision_events_iterator import DecisionEventsIterator from cadence._internal.workflow.statemachine.decision_manager import DecisionManager from cadence.api.v1.decision_pb2 import Decision -from cadence.client import Client from cadence.api.v1.service_worker_pb2 import PollForDecisionTaskResponse from cadence.workflow import WorkflowInfo @@ -21,16 +20,14 @@ class DecisionResult: class WorkflowEngine: - def __init__(self, info: WorkflowInfo, client: Client, workflow_definition=None): + def __init__(self, info: WorkflowInfo, workflow_definition=None): self._workflow_definition = workflow_definition self._workflow_instance = None if workflow_definition: self._workflow_instance = workflow_definition.cls() self._decision_manager = DecisionManager() self._decisions_helper = DecisionsHelper() - self._context = Context( - client, info, self._decisions_helper, self._decision_manager - ) + self._context = Context(info, self._decisions_helper, self._decision_manager) self._is_workflow_complete = False async def process_decision( @@ -65,7 +62,7 @@ async def process_decision( with self._context._activate(): # Create DecisionEventsIterator for structured event processing events_iterator = DecisionEventsIterator( - decision_task, self._context.client() + decision_task, self._context.info().workflow_events ) # Process decision events using iterator-driven approach @@ -360,10 +357,8 @@ def _extract_workflow_input( # Deserialize the input using the client's data converter try: # Use from_data method with a single type hint of None (no type conversion) - input_data_list = ( - self._context.client().data_converter.from_data( - started_attrs.input, [None] - ) + input_data_list = self._context.data_converter().from_data( + started_attrs.input, [None] ) input_data = input_data_list[0] if input_data_list else None logger.debug(f"Extracted workflow input: {input_data}") diff --git a/cadence/worker/_decision_task_handler.py b/cadence/worker/_decision_task_handler.py index 4710bc1..3ef0040 100644 --- a/cadence/worker/_decision_task_handler.py +++ b/cadence/worker/_decision_task_handler.py @@ -2,6 +2,7 @@ import threading from typing import Dict, Tuple +from cadence._internal.workflow.history_event_iterator import iterate_history_events from cadence.api.v1.common_pb2 import Payload from cadence.api.v1.service_worker_pb2 import ( PollForDecisionTaskResponse, @@ -102,6 +103,12 @@ async def _handle_task_implementation( ) raise KeyError(f"Workflow type '{workflow_type_name}' not found") + # fetch full workflow history + # TODO sticky cache + workflow_events = [ + event async for event in iterate_history_events(task, self._client) + ] + # Create workflow info and get or create workflow engine from cache workflow_info = WorkflowInfo( workflow_type=workflow_type_name, @@ -109,6 +116,8 @@ async def _handle_task_implementation( workflow_id=workflow_id, workflow_run_id=run_id, workflow_task_list=self.task_list, + data_converter=self._client.data_converter, + workflow_events=workflow_events, ) # Use thread-safe cache to get or create workflow engine @@ -118,7 +127,6 @@ async def _handle_task_implementation( if workflow_engine is 
None: workflow_engine = WorkflowEngine( info=workflow_info, - client=self._client, workflow_definition=workflow_definition, ) self._workflow_engines[cache_key] = workflow_engine diff --git a/cadence/workflow.py b/cadence/workflow.py index 913ebd1..68911e3 100644 --- a/cadence/workflow.py +++ b/cadence/workflow.py @@ -5,6 +5,7 @@ from datetime import timedelta from typing import ( Callable, + List, cast, Optional, Union, @@ -17,7 +18,7 @@ ) import inspect -from cadence.client import Client +from cadence.api.v1.history_pb2 import HistoryEvent from cadence.data_converter import DataConverter ResultType = TypeVar("ResultType") @@ -169,6 +170,8 @@ class WorkflowInfo: workflow_id: str workflow_run_id: str workflow_task_list: str + workflow_events: List[HistoryEvent] + data_converter: DataConverter class WorkflowContext(ABC): @@ -177,9 +180,6 @@ class WorkflowContext(ABC): @abstractmethod def info(self) -> WorkflowInfo: ... - @abstractmethod - def client(self) -> Client: ... - @abstractmethod def data_converter(self) -> DataConverter: ... diff --git a/tests/cadence/_internal/workflow/test_decision_events_iterator.py b/tests/cadence/_internal/workflow/test_decision_events_iterator.py index 94edef9..1e70661 100644 --- a/tests/cadence/_internal/workflow/test_decision_events_iterator.py +++ b/tests/cadence/_internal/workflow/test_decision_events_iterator.py @@ -4,14 +4,11 @@ """ import pytest -from unittest.mock import Mock, AsyncMock from typing import List from cadence.api.v1.history_pb2 import HistoryEvent, History from cadence.api.v1.service_worker_pb2 import PollForDecisionTaskResponse -from cadence.api.v1.service_workflow_pb2 import GetWorkflowExecutionHistoryResponse from cadence.api.v1.common_pb2 import WorkflowExecution -from cadence.client import Client from google.protobuf.timestamp_pb2 import Timestamp from cadence._internal.workflow.decision_events_iterator import ( @@ -73,16 +70,6 @@ def create_mock_decision_task( return task -@pytest.fixture -def mock_client(): - """Create a mock Cadence client.""" - client = Mock(spec=Client) - client.domain = "test-domain" - client.workflow_stub = Mock() - client.workflow_stub.GetWorkflowExecutionHistory = AsyncMock() - return client - - class TestDecisionEvents: """Test the DecisionEvents class.""" @@ -134,7 +121,7 @@ class TestDecisionEventsIterator: """Test the DecisionEventsIterator class.""" @pytest.mark.asyncio - async def test_single_decision_iteration(self, mock_client): + async def test_single_decision_iteration(self): """Test processing a single decision iteration.""" # Create events for a complete decision iteration events = [ @@ -147,8 +134,7 @@ async def test_single_decision_iteration(self, mock_client): ] decision_task = create_mock_decision_task(events) - iterator = DecisionEventsIterator(decision_task, mock_client) - await iterator._ensure_initialized() + iterator = DecisionEventsIterator(decision_task, events) assert await iterator.has_next_decision_events() @@ -164,7 +150,7 @@ async def test_single_decision_iteration(self, mock_client): assert decision_events.replay_current_time_milliseconds == 1000 * 1000 @pytest.mark.asyncio - async def test_multiple_decision_iterations(self, mock_client): + async def test_multiple_decision_iterations(self): """Test processing multiple decision iterations.""" # Create events for two decision iterations events = [ @@ -177,8 +163,7 @@ async def test_multiple_decision_iterations(self, mock_client): ] decision_task = create_mock_decision_task(events) - iterator = DecisionEventsIterator(decision_task, 
mock_client) - await iterator._ensure_initialized() + iterator = DecisionEventsIterator(decision_task, events) # First iteration assert await iterator.has_next_decision_events() @@ -196,47 +181,7 @@ async def test_multiple_decision_iterations(self, mock_client): assert not await iterator.has_next_decision_events() @pytest.mark.asyncio - async def test_pagination_support(self, mock_client): - """Test that pagination is handled correctly.""" - # First page events - first_page_events = [ - create_mock_history_event(1, "decision_task_started"), - create_mock_history_event(2, "decision_task_completed"), - ] - - # Second page events - second_page_events = [ - create_mock_history_event(3, "decision_task_started"), - create_mock_history_event(4, "decision_task_completed"), - ] - - # Mock the pagination response - pagination_response = GetWorkflowExecutionHistoryResponse() - pagination_history = History() - pagination_history.events.extend(second_page_events) - pagination_response.history.CopyFrom(pagination_history) - pagination_response.next_page_token = b"" # No more pages - - mock_client.workflow_stub.GetWorkflowExecutionHistory.return_value = ( - pagination_response - ) - - # Create decision task with next page token - decision_task = create_mock_decision_task(first_page_events, b"next-page-token") - iterator = DecisionEventsIterator(decision_task, mock_client) - await iterator._ensure_initialized() - - # Should process both pages - iterations_count = 0 - while await iterator.has_next_decision_events(): - await iterator.next_decision_events() - iterations_count += 1 - - assert iterations_count == 2 - assert mock_client.workflow_stub.GetWorkflowExecutionHistory.called - - @pytest.mark.asyncio - async def test_iterator_protocol(self, mock_client): + async def test_iterator_protocol(self): """Test that DecisionEventsIterator works with Python iterator protocol.""" events = [ create_mock_history_event(1, "decision_task_started"), @@ -246,8 +191,7 @@ async def test_iterator_protocol(self, mock_client): ] decision_task = create_mock_decision_task(events) - iterator = DecisionEventsIterator(decision_task, mock_client) - await iterator._ensure_initialized() + iterator = DecisionEventsIterator(decision_task, events) decision_events_list = [] async for decision_events in iterator: @@ -293,7 +237,7 @@ class TestIntegrationScenarios: """Test real-world integration scenarios.""" @pytest.mark.asyncio - async def test_replay_detection(self, mock_client): + async def test_replay_detection(self): """Test replay mode detection.""" # Simulate a scenario where we have historical events and current events events = [ @@ -306,8 +250,7 @@ async def test_replay_detection(self, mock_client): # Mock the started_event_id to indicate current decision decision_task.started_event_id = 3 - iterator = DecisionEventsIterator(decision_task, mock_client) - await iterator._ensure_initialized() + iterator = DecisionEventsIterator(decision_task, events) # First decision should be replay (but gets set to false when no more events) await iterator.next_decision_events() @@ -319,7 +262,7 @@ async def test_replay_detection(self, mock_client): # (This would need the completion event to trigger the replay mode change) @pytest.mark.asyncio - async def test_complex_workflow_scenario(self, mock_client): + async def test_complex_workflow_scenario(self): """Test a complex workflow with multiple event types.""" events = [ create_mock_history_event(1, "decision_task_started"), @@ -333,7 +276,7 @@ async def test_complex_workflow_scenario(self, 
mock_client): ] decision_task = create_mock_decision_task(events) - iterator = DecisionEventsIterator(decision_task, mock_client) + iterator = DecisionEventsIterator(decision_task, events) all_decisions = [] async for decision_events in iterator: diff --git a/tests/cadence/_internal/workflow/test_workflow_engine_integration.py b/tests/cadence/_internal/workflow/test_workflow_engine_integration.py index 3aa5d44..7a86ba7 100644 --- a/tests/cadence/_internal/workflow/test_workflow_engine_integration.py +++ b/tests/cadence/_internal/workflow/test_workflow_engine_integration.py @@ -33,12 +33,16 @@ def mock_client(self): @pytest.fixture def workflow_info(self): """Create workflow info.""" + data_converter = Mock() + data_converter.from_data = Mock(return_value=["test-input"]) return WorkflowInfo( workflow_type="test_workflow", workflow_domain="test-domain", workflow_id="test-workflow-id", workflow_run_id="test-run-id", workflow_task_list="test-task-list", + workflow_events = [], + data_converter=data_converter, ) @pytest.fixture @@ -54,11 +58,10 @@ async def weird_name(self, input_data): return WorkflowDefinition.wrap(TestWorkflow, workflow_opts) @pytest.fixture - def workflow_engine(self, mock_client, workflow_info, mock_workflow_definition): + def workflow_engine(self, workflow_info, mock_workflow_definition): """Create a WorkflowEngine instance.""" return WorkflowEngine( info=workflow_info, - client=mock_client, workflow_definition=mock_workflow_definition, ) diff --git a/tests/cadence/worker/test_decision_task_handler.py b/tests/cadence/worker/test_decision_task_handler.py index bf48cda..a3de410 100644 --- a/tests/cadence/worker/test_decision_task_handler.py +++ b/tests/cadence/worker/test_decision_task_handler.py @@ -7,6 +7,7 @@ from unittest.mock import Mock, AsyncMock, patch, PropertyMock from cadence.api.v1.common_pb2 import Payload +from cadence.api.v1.history_pb2 import History from cadence.api.v1.service_worker_pb2 import ( PollForDecisionTaskResponse, RespondDecisionTaskCompletedRequest, @@ -63,6 +64,8 @@ def sample_decision_task(self): # Add the missing attributes that are now accessed directly task.started_event_id = 1 task.attempt = 1 + task.history = History() + task.next_page_token = b"" return task def test_initialization(self, mock_client, mock_registry): @@ -83,7 +86,7 @@ def test_initialization(self, mock_client, mock_registry): @pytest.mark.asyncio async def test_handle_task_implementation_success( - self, handler, sample_decision_task, mock_registry + self, handler: DecisionTaskHandler, sample_decision_task, mock_registry ): """Test successful decision task handling.""" @@ -100,7 +103,6 @@ async def run(self): # Mock workflow engine mock_engine = Mock(spec=WorkflowEngine) mock_engine._is_workflow_complete = False # Add missing attribute - mock_engine._is_workflow_complete = False # Add missing attribute mock_decision_result = Mock(spec=DecisionResult) mock_decision_result.decisions = [Decision()] mock_engine.process_decision = AsyncMock(return_value=mock_decision_result) @@ -223,6 +225,8 @@ async def run(self): task1.workflow_type.name = "TestWorkflow" task1.started_event_id = 1 task1.attempt = 1 + task1.history = History() + task1.next_page_token = b"" task2 = Mock(spec=PollForDecisionTaskResponse) task2.task_token = b"test_task_token_2" @@ -233,6 +237,8 @@ async def run(self): task2.workflow_type.name = "TestWorkflow" task2.started_event_id = 2 task2.attempt = 1 + task2.history = History() + task2.next_page_token = b"" # Mock workflow engine mock_engine = 
Mock(spec=WorkflowEngine) @@ -442,11 +448,12 @@ async def run(self): "workflow_id": "test_workflow_id", "workflow_run_id": "test_run_id", "workflow_task_list": "test_task_list", + "data_converter": handler._client.data_converter, + "workflow_events": [], } # Verify WorkflowEngine was created with correct parameters mock_workflow_engine_class.assert_called_once() call_args = mock_workflow_engine_class.call_args assert call_args[1]["info"] is not None - assert call_args[1]["client"] == handler._client assert call_args[1]["workflow_definition"] == workflow_definition diff --git a/tests/cadence/worker/test_task_handler_integration.py b/tests/cadence/worker/test_task_handler_integration.py index 5fec846..3950c46 100644 --- a/tests/cadence/worker/test_task_handler_integration.py +++ b/tests/cadence/worker/test_task_handler_integration.py @@ -7,6 +7,7 @@ from contextlib import contextmanager from unittest.mock import Mock, AsyncMock, patch, PropertyMock +from cadence.api.v1.history_pb2 import History from cadence.api.v1.service_worker_pb2 import PollForDecisionTaskResponse from cadence.client import Client from cadence.worker._decision_task_handler import DecisionTaskHandler @@ -58,6 +59,8 @@ def sample_decision_task(self): # Add the missing attributes that are now accessed directly task.started_event_id = 1 task.attempt = 1 + task.history = History() + task.next_page_token = b"" return task @pytest.mark.asyncio @@ -205,6 +208,8 @@ async def run(self): task1.workflow_type.name = "TestWorkflow" task1.started_event_id = 1 task1.attempt = 1 + task1.history = History() + task1.next_page_token = b"" task2 = Mock(spec=PollForDecisionTaskResponse) task2.task_token = b"task2_token" @@ -215,7 +220,8 @@ async def run(self): task2.workflow_type.name = "TestWorkflow" task2.started_event_id = 2 task2.attempt = 1 - + task2.history = History() + task2.next_page_token = b"" # Mock workflow engine mock_engine = Mock(spec=WorkflowEngine) mock_engine._is_workflow_complete = False # Add missing attribute @@ -352,6 +358,8 @@ async def run(self): task.started_event_id = i + 1 task.attempt = 1 tasks.append(task) + task.history = History() + task.next_page_token = b"" # Mock workflow engine mock_engine = Mock(spec=WorkflowEngine) From 4d7b99daac04b9aa0da45ff0f26fe6018d0c623a Mon Sep 17 00:00:00 2001 From: Shijie Sheng Date: Wed, 5 Nov 2025 14:35:53 -0800 Subject: [PATCH 2/2] fix unit test Signed-off-by: Shijie Sheng --- .../test_workflow_engine_integration.py | 52 +++++++++---------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/tests/cadence/_internal/workflow/test_workflow_engine_integration.py b/tests/cadence/_internal/workflow/test_workflow_engine_integration.py index 7a86ba7..7a29195 100644 --- a/tests/cadence/_internal/workflow/test_workflow_engine_integration.py +++ b/tests/cadence/_internal/workflow/test_workflow_engine_integration.py @@ -31,20 +31,22 @@ def mock_client(self): return client @pytest.fixture - def workflow_info(self): + def workflow_info(self, mock_client, decision_task): """Create workflow info.""" - data_converter = Mock() - data_converter.from_data = Mock(return_value=["test-input"]) return WorkflowInfo( workflow_type="test_workflow", workflow_domain="test-domain", workflow_id="test-workflow-id", workflow_run_id="test-run-id", workflow_task_list="test-task-list", - workflow_events = [], - data_converter=data_converter, + workflow_events=decision_task.history.events, + data_converter=mock_client.data_converter, ) + @pytest.fixture + def decision_task(self): + return 
self.create_mock_decision_task() + @pytest.fixture def mock_workflow_definition(self): """Create a mock workflow definition.""" @@ -105,9 +107,10 @@ def create_mock_decision_task( return decision_task @pytest.mark.asyncio - async def test_process_decision_success(self, workflow_engine, mock_client): + async def test_process_decision_success( + self, workflow_engine, mock_client, decision_task + ): """Test successful decision processing.""" - decision_task = self.create_mock_decision_task() # Mock the decision manager to return some decisions with patch.object( @@ -123,9 +126,10 @@ async def test_process_decision_success(self, workflow_engine, mock_client): assert len(result.decisions) == 1 @pytest.mark.asyncio - async def test_process_decision_with_history(self, workflow_engine, mock_client): + async def test_process_decision_with_history( + self, workflow_engine, mock_client, decision_task + ): """Test decision processing with history events.""" - decision_task = self.create_mock_decision_task() # Mock the decision manager with patch.object( @@ -144,14 +148,12 @@ async def test_process_decision_with_history(self, workflow_engine, mock_client) @pytest.mark.asyncio async def test_process_decision_workflow_complete( - self, workflow_engine, mock_client + self, workflow_engine, mock_client, decision_task ): """Test decision processing when workflow is already complete.""" # Mark workflow as complete workflow_engine._is_workflow_complete = True - decision_task = self.create_mock_decision_task() - with patch.object( workflow_engine._decision_manager, "collect_pending_decisions", @@ -165,9 +167,10 @@ async def test_process_decision_workflow_complete( assert len(result.decisions) == 0 @pytest.mark.asyncio - async def test_process_decision_error_handling(self, workflow_engine, mock_client): + async def test_process_decision_error_handling( + self, workflow_engine, mock_client, decision_task + ): """Test decision processing error handling.""" - decision_task = self.create_mock_decision_task() # Mock the decision manager to raise an exception with patch.object( @@ -183,9 +186,10 @@ async def test_process_decision_error_handling(self, workflow_engine, mock_clien assert len(result.decisions) == 0 @pytest.mark.asyncio - async def test_extract_workflow_input_success(self, workflow_engine, mock_client): + async def test_extract_workflow_input_success( + self, workflow_engine: "WorkflowEngine", mock_client, decision_task + ): """Test successful workflow input extraction.""" - decision_task = self.create_mock_decision_task() # Extract workflow input input_data = workflow_engine._extract_workflow_input(decision_task) @@ -241,10 +245,9 @@ async def test_extract_workflow_input_no_started_event( @pytest.mark.asyncio async def test_extract_workflow_input_deserialization_error( - self, workflow_engine, mock_client + self, workflow_engine, mock_client, decision_task ): """Test workflow input extraction with deserialization error.""" - decision_task = self.create_mock_decision_task() # Mock data converter to raise an exception mock_client.data_converter.from_data = AsyncMock( @@ -313,15 +316,14 @@ def test_workflow_engine_initialization( @pytest.mark.asyncio async def test_workflow_engine_without_workflow_definition( - self, mock_client, workflow_info + self, mock_client: Client, workflow_info, decision_task ): """Test WorkflowEngine without workflow definition.""" engine = WorkflowEngine( - info=workflow_info, client=mock_client, workflow_definition=None + info=workflow_info, + workflow_definition=None, ) - 
decision_task = self.create_mock_decision_task() - with patch.object( engine._decision_manager, "collect_pending_decisions", return_value=[] ): @@ -334,10 +336,9 @@ async def test_workflow_engine_without_workflow_definition( @pytest.mark.asyncio async def test_workflow_engine_workflow_completion( - self, workflow_engine, mock_client + self, workflow_engine, mock_client, decision_task ): """Test workflow completion detection.""" - decision_task = self.create_mock_decision_task() # Create a workflow definition that returns a result (indicating completion) class CompletingWorkflow: @@ -372,10 +373,9 @@ def test_close_event_loop(self, workflow_engine): @pytest.mark.asyncio async def test_process_decision_with_query_results( - self, workflow_engine, mock_client + self, workflow_engine, mock_client, decision_task ): """Test decision processing with query results.""" - decision_task = self.create_mock_decision_task() # Mock the decision manager to return decisions with query results mock_decisions = [Mock()]
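For reference, the sketch below shows how the pieces touched by this patch are intended to fit together after the refactor: the decision task handler, not the WorkflowEngine, fetches the full workflow history and hands it to the engine through WorkflowInfo, and the engine builds its DecisionEventsIterator from that pre-fetched event list instead of a Client. This is a minimal illustration only; the names are taken from the diff above, while the surrounding wiring (client, task, registry, task list, domain) is assumed and simplified.

# Minimal sketch (not part of the patch): how the refactored pieces fit together.
# Names come from the diff above; the worker/client/registry wiring is assumed.
from cadence._internal.workflow.history_event_iterator import iterate_history_events
from cadence._internal.workflow.workflow_engine import WorkflowEngine
from cadence.workflow import WorkflowInfo


async def handle_decision_task(task, client, workflow_definition, task_list):
    # History fetching now lives in the task handler (sticky cache still TODO),
    # so WorkflowEngine and DecisionEventsIterator never touch the Client.
    events = [event async for event in iterate_history_events(task, client)]

    info = WorkflowInfo(
        workflow_type=task.workflow_type.name,
        workflow_domain=client.domain,
        workflow_id=task.workflow_execution.workflow_id,
        workflow_run_id=task.workflow_execution.run_id,
        workflow_task_list=task_list,
        workflow_events=events,                 # new field introduced by this patch
        data_converter=client.data_converter,   # new field introduced by this patch
    )

    # The engine reads events and the data converter from WorkflowInfo;
    # internally it constructs DecisionEventsIterator(task, info.workflow_events).
    engine = WorkflowEngine(info=info, workflow_definition=workflow_definition)
    result = await engine.process_decision(task)
    return result.decisions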