From 0ac92cd5d6aa49dea2eda87b4e643fb686b29551 Mon Sep 17 00:00:00 2001 From: Tim Li Date: Thu, 2 Oct 2025 12:28:19 -0700 Subject: [PATCH 1/8] Add workflow invocation methods Signed-off-by: Tim Li --- cadence/client.py | 183 +++++++++++- tests/cadence/test_client_workflow.py | 387 ++++++++++++++++++++++++++ 2 files changed, 569 insertions(+), 1 deletion(-) create mode 100644 tests/cadence/test_client_workflow.py diff --git a/cadence/client.py b/cadence/client.py index 77ec95c..2b4029e 100644 --- a/cadence/client.py +++ b/cadence/client.py @@ -1,8 +1,12 @@ import os import socket -from typing import TypedDict, Unpack, Any, cast +import uuid +from dataclasses import dataclass +from datetime import timedelta +from typing import TypedDict, Unpack, Any, cast, Union, Optional, Callable from grpc import ChannelCredentials, Compression +from google.protobuf.duration_pb2 import Duration from cadence._internal.rpc.error import CadenceErrorInterceptor from cadence._internal.rpc.retry import RetryInterceptor @@ -11,10 +15,51 @@ from cadence.api.v1.service_worker_pb2_grpc import WorkerAPIStub from grpc.aio import Channel, ClientInterceptor, secure_channel, insecure_channel from cadence.api.v1.service_workflow_pb2_grpc import WorkflowAPIStub +from cadence.api.v1.service_workflow_pb2 import StartWorkflowExecutionRequest, StartWorkflowExecutionResponse +from cadence.api.v1.common_pb2 import WorkflowType, WorkflowExecution +from cadence.api.v1.tasklist_pb2 import TaskList +from cadence.api.v1.workflow_pb2 import WorkflowIdReusePolicy from cadence.data_converter import DataConverter, DefaultDataConverter from cadence.metrics import MetricsEmitter, NoOpMetricsEmitter +@dataclass +class WorkflowRun: + """Represents a workflow run that can be used to get results.""" + execution: WorkflowExecution + client: 'Client' + + @property + def workflow_id(self) -> str: + """Get the workflow ID.""" + return self.execution.workflow_id + + @property + def run_id(self) -> str: + """Get the run ID.""" + return self.execution.run_id + + async def get_result(self, result_type: Optional[type] = None) -> Any: # noqa: ARG002 + """Wait for workflow completion and return result.""" + # TODO: Implement workflow result retrieval + # This would involve polling GetWorkflowExecutionHistory until completion + # and extracting the result from the final event + raise NotImplementedError("get_result not yet implemented") + + +@dataclass +class StartWorkflowOptions: + """Options for starting a workflow execution.""" + workflow_id: Optional[str] = None + task_list: str = "" + execution_start_to_close_timeout: Optional[timedelta] = None + task_start_to_close_timeout: Optional[timedelta] = None + workflow_id_reuse_policy: int = WorkflowIdReusePolicy.WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE + cron_schedule: Optional[str] = None + memo: Optional[dict[str, Any]] = None + search_attributes: Optional[dict[str, Any]] = None + + class ClientOptions(TypedDict, total=False): domain: str target: str @@ -88,6 +133,142 @@ async def __aenter__(self) -> 'Client': async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: await self.close() + async def _build_start_workflow_request( + self, + workflow: Union[str, Callable], + args: tuple[Any, ...], + options: StartWorkflowOptions + ) -> StartWorkflowExecutionRequest: + """Build a StartWorkflowExecutionRequest from parameters.""" + # Generate workflow ID if not provided + workflow_id = options.workflow_id or str(uuid.uuid4()) + + # Validate required fields + if not options.task_list: + raise 
ValueError("task_list is required") + + # Determine workflow type name + if isinstance(workflow, str): + workflow_type_name = workflow + else: + # For callable, use function name or __name__ attribute + workflow_type_name = getattr(workflow, '__name__', str(workflow)) + + # Encode input arguments + input_payload = None + if args: + try: + input_payload = await self.data_converter.to_data(list(args)) + except Exception as e: + raise ValueError(f"Failed to encode workflow arguments: {e}") + + # Convert timedelta to protobuf Duration + execution_timeout = None + if options.execution_start_to_close_timeout: + execution_timeout = Duration() + execution_timeout.FromTimedelta(options.execution_start_to_close_timeout) + + task_timeout = None + if options.task_start_to_close_timeout: + task_timeout = Duration() + task_timeout.FromTimedelta(options.task_start_to_close_timeout) + + # Build the request + request = StartWorkflowExecutionRequest( + domain=self.domain, + workflow_id=workflow_id, + workflow_type=WorkflowType(name=workflow_type_name), + task_list=TaskList(name=options.task_list), + identity=self.identity, + request_id=str(uuid.uuid4()) + ) + + # Set workflow_id_reuse_policy separately to avoid type issues + request.workflow_id_reuse_policy = options.workflow_id_reuse_policy # type: ignore[assignment] + + # Set optional fields + if input_payload: + request.input.CopyFrom(input_payload) + if execution_timeout: + request.execution_start_to_close_timeout.CopyFrom(execution_timeout) + if task_timeout: + request.task_start_to_close_timeout.CopyFrom(task_timeout) + if options.cron_schedule: + request.cron_schedule = options.cron_schedule + + return request + + async def start_workflow( + self, + workflow: Union[str, Callable], + *args, + **options_kwargs + ) -> WorkflowExecution: + """ + Start a workflow execution asynchronously. + + Args: + workflow: Workflow function or workflow type name string + *args: Arguments to pass to the workflow + **options_kwargs: StartWorkflowOptions as keyword arguments + + Returns: + WorkflowExecution with workflow_id and run_id + + Raises: + ValueError: If required parameters are missing or invalid + Exception: If the gRPC call fails + """ + # Convert kwargs to StartWorkflowOptions + options = StartWorkflowOptions(**options_kwargs) + + # Build the gRPC request + request = await self._build_start_workflow_request(workflow, args, options) + + # Execute the gRPC call + try: + response: StartWorkflowExecutionResponse = await self.workflow_stub.StartWorkflowExecution(request) + + # Emit metrics if available + if self.metrics_emitter: + # TODO: Add workflow start metrics similar to Go client + pass + + execution = WorkflowExecution() + execution.workflow_id = request.workflow_id + execution.run_id = response.run_id + return execution + except Exception as e: + raise Exception(f"Failed to start workflow: {e}") from e + + async def execute_workflow( + self, + workflow: Union[str, Callable], + *args, + **options_kwargs + ) -> WorkflowRun: + """ + Start a workflow execution and return a handle to get the result. 
+ + Args: + workflow: Workflow function or workflow type name string + *args: Arguments to pass to the workflow + **options_kwargs: StartWorkflowOptions as keyword arguments + + Returns: + WorkflowRun that can be used to get the workflow result + + Raises: + ValueError: If required parameters are missing or invalid + Exception: If the gRPC call fails + """ + execution = await self.start_workflow(workflow, *args, **options_kwargs) + + return WorkflowRun( + execution=execution, + client=self + ) + def _validate_and_copy_defaults(options: ClientOptions) -> ClientOptions: if "target" not in options: raise ValueError("target must be specified") diff --git a/tests/cadence/test_client_workflow.py b/tests/cadence/test_client_workflow.py new file mode 100644 index 0000000..40e5d54 --- /dev/null +++ b/tests/cadence/test_client_workflow.py @@ -0,0 +1,387 @@ +import pytest +import uuid +from datetime import timedelta +from unittest.mock import AsyncMock, Mock, PropertyMock + +from cadence.api.v1.common_pb2 import WorkflowExecution +from cadence.api.v1.service_workflow_pb2 import StartWorkflowExecutionRequest, StartWorkflowExecutionResponse +from cadence.api.v1.workflow_pb2 import WorkflowIdReusePolicy +from cadence.client import Client, StartWorkflowOptions, WorkflowRun +from cadence.data_converter import DefaultDataConverter + + +@pytest.fixture +def mock_client(): + """Create a mock client for testing.""" + client = Mock(spec=Client) + type(client).domain = PropertyMock(return_value="test-domain") + type(client).identity = PropertyMock(return_value="test-identity") + type(client).data_converter = PropertyMock(return_value=DefaultDataConverter()) + type(client).metrics_emitter = PropertyMock(return_value=None) + + # Mock the workflow stub + workflow_stub = Mock() + type(client).workflow_stub = PropertyMock(return_value=workflow_stub) + + return client + + +class TestStartWorkflowOptions: + """Test StartWorkflowOptions dataclass.""" + + def test_default_values(self): + """Test default values for StartWorkflowOptions.""" + options = StartWorkflowOptions() + assert options.workflow_id is None + assert options.task_list == "" + assert options.execution_start_to_close_timeout is None + assert options.task_start_to_close_timeout is None + assert options.workflow_id_reuse_policy == WorkflowIdReusePolicy.WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE + assert options.cron_schedule is None + assert options.memo is None + assert options.search_attributes is None + + def test_custom_values(self): + """Test setting custom values for StartWorkflowOptions.""" + options = StartWorkflowOptions( + workflow_id="custom-id", + task_list="test-task-list", + execution_start_to_close_timeout=timedelta(minutes=30), + task_start_to_close_timeout=timedelta(seconds=10), + workflow_id_reuse_policy=WorkflowIdReusePolicy.WORKFLOW_ID_REUSE_POLICY_REJECT_DUPLICATE, + cron_schedule="0 * * * *", + memo={"key": "value"}, + search_attributes={"attr": "value"} + ) + + assert options.workflow_id == "custom-id" + assert options.task_list == "test-task-list" + assert options.execution_start_to_close_timeout == timedelta(minutes=30) + assert options.task_start_to_close_timeout == timedelta(seconds=10) + assert options.workflow_id_reuse_policy == WorkflowIdReusePolicy.WORKFLOW_ID_REUSE_POLICY_REJECT_DUPLICATE + assert options.cron_schedule == "0 * * * *" + assert options.memo == {"key": "value"} + assert options.search_attributes == {"attr": "value"} + + +class TestWorkflowRun: + """Test WorkflowRun class.""" + + def test_properties(self, 
mock_client): + """Test WorkflowRun properties.""" + execution = WorkflowExecution() + execution.workflow_id = "test-workflow-id" + execution.run_id = "test-run-id" + + workflow_run = WorkflowRun(execution=execution, client=mock_client) + + assert workflow_run.workflow_id == "test-workflow-id" + assert workflow_run.run_id == "test-run-id" + assert workflow_run.client is mock_client + + @pytest.mark.asyncio + async def test_get_result_not_implemented(self, mock_client): + """Test that get_result raises NotImplementedError.""" + execution = WorkflowExecution() + execution.workflow_id = "test-workflow-id" + execution.run_id = "test-run-id" + + workflow_run = WorkflowRun(execution=execution, client=mock_client) + + with pytest.raises(NotImplementedError, match="get_result not yet implemented"): + await workflow_run.get_result() + + +class TestClientBuildStartWorkflowRequest: + """Test Client._build_start_workflow_request method.""" + + @pytest.mark.asyncio + async def test_build_request_with_string_workflow(self, mock_client): + """Test building request with string workflow name.""" + # Create real client instance to test the method + client = Client(domain="test-domain", target="localhost:7933") + + options = StartWorkflowOptions( + workflow_id="test-workflow-id", + task_list="test-task-list", + execution_start_to_close_timeout=timedelta(minutes=30), + task_start_to_close_timeout=timedelta(seconds=10) + ) + + request = await client._build_start_workflow_request("TestWorkflow", ("arg1", "arg2"), options) + + assert isinstance(request, StartWorkflowExecutionRequest) + assert request.domain == "test-domain" + assert request.workflow_id == "test-workflow-id" + assert request.workflow_type.name == "TestWorkflow" + assert request.task_list.name == "test-task-list" + assert request.identity == client.identity + assert request.workflow_id_reuse_policy == WorkflowIdReusePolicy.WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE + assert request.request_id != "" # Should be a UUID + + # Verify UUID format + uuid.UUID(request.request_id) # This will raise if not valid UUID + + @pytest.mark.asyncio + async def test_build_request_with_callable_workflow(self, mock_client): + """Test building request with callable workflow.""" + def test_workflow(): + pass + + client = Client(domain="test-domain", target="localhost:7933") + + options = StartWorkflowOptions( + task_list="test-task-list" + ) + + request = await client._build_start_workflow_request(test_workflow, (), options) + + assert request.workflow_type.name == "test_workflow" + + @pytest.mark.asyncio + async def test_build_request_generates_workflow_id(self, mock_client): + """Test that workflow_id is generated when not provided.""" + client = Client(domain="test-domain", target="localhost:7933") + + options = StartWorkflowOptions(task_list="test-task-list") + + request = await client._build_start_workflow_request("TestWorkflow", (), options) + + assert request.workflow_id != "" + # Verify it's a valid UUID + uuid.UUID(request.workflow_id) + + @pytest.mark.asyncio + async def test_build_request_missing_task_list(self, mock_client): + """Test that missing task_list raises ValueError.""" + client = Client(domain="test-domain", target="localhost:7933") + + options = StartWorkflowOptions() # No task_list + + with pytest.raises(ValueError, match="task_list is required"): + await client._build_start_workflow_request("TestWorkflow", (), options) + + @pytest.mark.asyncio + async def test_build_request_with_input_args(self, mock_client): + """Test building request with 
input arguments.""" + client = Client(domain="test-domain", target="localhost:7933") + + options = StartWorkflowOptions(task_list="test-task-list") + + request = await client._build_start_workflow_request("TestWorkflow", ("arg1", 42, {"key": "value"}), options) + + # Should have input payload + assert request.HasField("input") + assert len(request.input.data) > 0 + + @pytest.mark.asyncio + async def test_build_request_with_timeouts(self, mock_client): + """Test building request with timeout settings.""" + client = Client(domain="test-domain", target="localhost:7933") + + options = StartWorkflowOptions( + task_list="test-task-list", + execution_start_to_close_timeout=timedelta(minutes=30), + task_start_to_close_timeout=timedelta(seconds=10) + ) + + request = await client._build_start_workflow_request("TestWorkflow", (), options) + + assert request.HasField("execution_start_to_close_timeout") + assert request.HasField("task_start_to_close_timeout") + + # Check timeout values (30 minutes = 1800 seconds) + assert request.execution_start_to_close_timeout.seconds == 1800 + assert request.task_start_to_close_timeout.seconds == 10 + + @pytest.mark.asyncio + async def test_build_request_with_cron_schedule(self, mock_client): + """Test building request with cron schedule.""" + client = Client(domain="test-domain", target="localhost:7933") + + options = StartWorkflowOptions( + task_list="test-task-list", + cron_schedule="0 * * * *" + ) + + request = await client._build_start_workflow_request("TestWorkflow", (), options) + + assert request.cron_schedule == "0 * * * *" + + +class TestClientStartWorkflow: + """Test Client.start_workflow method.""" + + @pytest.mark.asyncio + async def test_start_workflow_success(self, mock_client): + """Test successful workflow start.""" + # Setup mock response + response = StartWorkflowExecutionResponse() + response.run_id = "test-run-id" + + mock_client.workflow_stub.StartWorkflowExecution = AsyncMock(return_value=response) + + # Create a real client but replace the workflow_stub + client = Client(domain="test-domain", target="localhost:7933") + client._workflow_stub = mock_client.workflow_stub + + # Mock the internal method to avoid full request building + async def mock_build_request(workflow, args, options): + request = StartWorkflowExecutionRequest() + request.workflow_id = "test-workflow-id" + request.domain = "test-domain" + return request + + client._build_start_workflow_request = AsyncMock(side_effect=mock_build_request) + + execution = await client.start_workflow( + "TestWorkflow", + "arg1", "arg2", + task_list="test-task-list", + workflow_id="test-workflow-id" + ) + + assert isinstance(execution, WorkflowExecution) + assert execution.workflow_id == "test-workflow-id" + assert execution.run_id == "test-run-id" + + # Verify the gRPC call was made + mock_client.workflow_stub.StartWorkflowExecution.assert_called_once() + + @pytest.mark.asyncio + async def test_start_workflow_grpc_error(self, mock_client): + """Test workflow start with gRPC error.""" + # Setup mock to raise exception + mock_client.workflow_stub.StartWorkflowExecution = AsyncMock(side_effect=Exception("gRPC error")) + + client = Client(domain="test-domain", target="localhost:7933") + client._workflow_stub = mock_client.workflow_stub + + # Mock the internal method + client._build_start_workflow_request = AsyncMock(return_value=StartWorkflowExecutionRequest()) + + with pytest.raises(Exception, match="Failed to start workflow: gRPC error"): + await client.start_workflow( + "TestWorkflow", + 
task_list="test-task-list" + ) + + @pytest.mark.asyncio + async def test_start_workflow_with_kwargs(self, mock_client): + """Test start_workflow with options as kwargs.""" + response = StartWorkflowExecutionResponse() + response.run_id = "test-run-id" + + mock_client.workflow_stub.StartWorkflowExecution = AsyncMock(return_value=response) + + client = Client(domain="test-domain", target="localhost:7933") + client._workflow_stub = mock_client.workflow_stub + + # Mock the internal method to capture options + captured_options = None + async def mock_build_request(workflow, args, options): + nonlocal captured_options + captured_options = options + request = StartWorkflowExecutionRequest() + request.workflow_id = "test-workflow-id" + return request + + client._build_start_workflow_request = AsyncMock(side_effect=mock_build_request) + + await client.start_workflow( + "TestWorkflow", + "arg1", + task_list="test-task-list", + workflow_id="custom-id", + execution_start_to_close_timeout=timedelta(minutes=30) + ) + + # Verify options were properly constructed + assert captured_options.task_list == "test-task-list" + assert captured_options.workflow_id == "custom-id" + assert captured_options.execution_start_to_close_timeout == timedelta(minutes=30) + + +class TestClientExecuteWorkflow: + """Test Client.execute_workflow method.""" + + @pytest.mark.asyncio + async def test_execute_workflow_success(self, mock_client): + """Test successful workflow execution.""" + # Mock start_workflow to return execution + execution = WorkflowExecution() + execution.workflow_id = "test-workflow-id" + execution.run_id = "test-run-id" + + client = Client(domain="test-domain", target="localhost:7933") + client.start_workflow = AsyncMock(return_value=execution) + + workflow_run = await client.execute_workflow( + "TestWorkflow", + "arg1", "arg2", + task_list="test-task-list" + ) + + assert isinstance(workflow_run, WorkflowRun) + assert workflow_run.execution is execution + assert workflow_run.client is client + assert workflow_run.workflow_id == "test-workflow-id" + assert workflow_run.run_id == "test-run-id" + + # Verify start_workflow was called with correct arguments + client.start_workflow.assert_called_once_with( + "TestWorkflow", + "arg1", "arg2", + task_list="test-task-list" + ) + + @pytest.mark.asyncio + async def test_execute_workflow_propagates_error(self, mock_client): + """Test that execute_workflow propagates errors from start_workflow.""" + client = Client(domain="test-domain", target="localhost:7933") + client.start_workflow = AsyncMock(side_effect=ValueError("Invalid task_list")) + + with pytest.raises(ValueError, match="Invalid task_list"): + await client.execute_workflow( + "TestWorkflow", + task_list="" + ) + + +@pytest.mark.asyncio +async def test_integration_workflow_invocation(): + """Integration test for workflow invocation flow.""" + # This test verifies the complete flow works together + response = StartWorkflowExecutionResponse() + response.run_id = "integration-run-id" + + # Create client with mocked gRPC stub + client = Client(domain="test-domain", target="localhost:7933") + client._workflow_stub = Mock() + client._workflow_stub.StartWorkflowExecution = AsyncMock(return_value=response) + + # Test the complete flow + workflow_run = await client.execute_workflow( + "IntegrationTestWorkflow", + "test-arg", + 42, + {"data": "value"}, + task_list="integration-task-list", + workflow_id="integration-workflow-id", + execution_start_to_close_timeout=timedelta(minutes=10) + ) + + # Verify result + assert 
workflow_run.workflow_id == "integration-workflow-id" + assert workflow_run.run_id == "integration-run-id" + + # Verify the gRPC call was made with proper request + client._workflow_stub.StartWorkflowExecution.assert_called_once() + request = client._workflow_stub.StartWorkflowExecution.call_args[0][0] + + assert request.domain == "test-domain" + assert request.workflow_id == "integration-workflow-id" + assert request.workflow_type.name == "IntegrationTestWorkflow" + assert request.task_list.name == "integration-task-list" + assert request.HasField("input") # Should have encoded input + assert request.HasField("execution_start_to_close_timeout") \ No newline at end of file From c357057ceb062a82e9626c85c6a70e77ad25ee58 Mon Sep 17 00:00:00 2001 From: Tim Li Date: Fri, 3 Oct 2025 14:34:58 -0700 Subject: [PATCH 2/8] minor change Signed-off-by: Tim Li --- cadence/client.py | 36 +++------------------ tests/cadence/test_client_workflow.py | 45 ++++++--------------------- 2 files changed, 14 insertions(+), 67 deletions(-) diff --git a/cadence/client.py b/cadence/client.py index 2b4029e..51a6f68 100644 --- a/cadence/client.py +++ b/cadence/client.py @@ -23,29 +23,6 @@ from cadence.metrics import MetricsEmitter, NoOpMetricsEmitter -@dataclass -class WorkflowRun: - """Represents a workflow run that can be used to get results.""" - execution: WorkflowExecution - client: 'Client' - - @property - def workflow_id(self) -> str: - """Get the workflow ID.""" - return self.execution.workflow_id - - @property - def run_id(self) -> str: - """Get the run ID.""" - return self.execution.run_id - - async def get_result(self, result_type: Optional[type] = None) -> Any: # noqa: ARG002 - """Wait for workflow completion and return result.""" - # TODO: Implement workflow result retrieval - # This would involve polling GetWorkflowExecutionHistory until completion - # and extracting the result from the final event - raise NotImplementedError("get_result not yet implemented") - @dataclass class StartWorkflowOptions: @@ -246,9 +223,9 @@ async def execute_workflow( workflow: Union[str, Callable], *args, **options_kwargs - ) -> WorkflowRun: + ) -> WorkflowExecution: """ - Start a workflow execution and return a handle to get the result. + Start a workflow execution and return the execution handle. 
Args: workflow: Workflow function or workflow type name string @@ -256,18 +233,15 @@ async def execute_workflow( **options_kwargs: StartWorkflowOptions as keyword arguments Returns: - WorkflowRun that can be used to get the workflow result + WorkflowExecution that contains workflow_id and run_id Raises: ValueError: If required parameters are missing or invalid Exception: If the gRPC call fails """ - execution = await self.start_workflow(workflow, *args, **options_kwargs) + return await self.start_workflow(workflow, *args, **options_kwargs) + - return WorkflowRun( - execution=execution, - client=self - ) def _validate_and_copy_defaults(options: ClientOptions) -> ClientOptions: if "target" not in options: diff --git a/tests/cadence/test_client_workflow.py b/tests/cadence/test_client_workflow.py index 40e5d54..a6087ef 100644 --- a/tests/cadence/test_client_workflow.py +++ b/tests/cadence/test_client_workflow.py @@ -6,7 +6,7 @@ from cadence.api.v1.common_pb2 import WorkflowExecution from cadence.api.v1.service_workflow_pb2 import StartWorkflowExecutionRequest, StartWorkflowExecutionResponse from cadence.api.v1.workflow_pb2 import WorkflowIdReusePolicy -from cadence.client import Client, StartWorkflowOptions, WorkflowRun +from cadence.client import Client, StartWorkflowOptions from cadence.data_converter import DefaultDataConverter @@ -64,32 +64,6 @@ def test_custom_values(self): assert options.search_attributes == {"attr": "value"} -class TestWorkflowRun: - """Test WorkflowRun class.""" - - def test_properties(self, mock_client): - """Test WorkflowRun properties.""" - execution = WorkflowExecution() - execution.workflow_id = "test-workflow-id" - execution.run_id = "test-run-id" - - workflow_run = WorkflowRun(execution=execution, client=mock_client) - - assert workflow_run.workflow_id == "test-workflow-id" - assert workflow_run.run_id == "test-run-id" - assert workflow_run.client is mock_client - - @pytest.mark.asyncio - async def test_get_result_not_implemented(self, mock_client): - """Test that get_result raises NotImplementedError.""" - execution = WorkflowExecution() - execution.workflow_id = "test-workflow-id" - execution.run_id = "test-run-id" - - workflow_run = WorkflowRun(execution=execution, client=mock_client) - - with pytest.raises(NotImplementedError, match="get_result not yet implemented"): - await workflow_run.get_result() class TestClientBuildStartWorkflowRequest: @@ -316,17 +290,16 @@ async def test_execute_workflow_success(self, mock_client): client = Client(domain="test-domain", target="localhost:7933") client.start_workflow = AsyncMock(return_value=execution) - workflow_run = await client.execute_workflow( + result_execution = await client.execute_workflow( "TestWorkflow", "arg1", "arg2", task_list="test-task-list" ) - assert isinstance(workflow_run, WorkflowRun) - assert workflow_run.execution is execution - assert workflow_run.client is client - assert workflow_run.workflow_id == "test-workflow-id" - assert workflow_run.run_id == "test-run-id" + assert isinstance(result_execution, WorkflowExecution) + assert result_execution is execution + assert result_execution.workflow_id == "test-workflow-id" + assert result_execution.run_id == "test-run-id" # Verify start_workflow was called with correct arguments client.start_workflow.assert_called_once_with( @@ -361,7 +334,7 @@ async def test_integration_workflow_invocation(): client._workflow_stub.StartWorkflowExecution = AsyncMock(return_value=response) # Test the complete flow - workflow_run = await client.execute_workflow( + 
execution = await client.execute_workflow( "IntegrationTestWorkflow", "test-arg", 42, @@ -372,8 +345,8 @@ async def test_integration_workflow_invocation(): ) # Verify result - assert workflow_run.workflow_id == "integration-workflow-id" - assert workflow_run.run_id == "integration-run-id" + assert execution.workflow_id == "integration-workflow-id" + assert execution.run_id == "integration-run-id" # Verify the gRPC call was made with proper request client._workflow_stub.StartWorkflowExecution.assert_called_once() From 188bb6d2d96c0399958382a442afbe8debb0375a Mon Sep 17 00:00:00 2001 From: Tim Li Date: Fri, 3 Oct 2025 14:50:09 -0700 Subject: [PATCH 3/8] remove duplicate logic Signed-off-by: Tim Li --- cadence/client.py | 22 ------------- tests/cadence/test_client_workflow.py | 45 ++++++++++++++------------- 2 files changed, 23 insertions(+), 44 deletions(-) diff --git a/cadence/client.py b/cadence/client.py index 51a6f68..0b52da4 100644 --- a/cadence/client.py +++ b/cadence/client.py @@ -218,28 +218,6 @@ async def start_workflow( except Exception as e: raise Exception(f"Failed to start workflow: {e}") from e - async def execute_workflow( - self, - workflow: Union[str, Callable], - *args, - **options_kwargs - ) -> WorkflowExecution: - """ - Start a workflow execution and return the execution handle. - - Args: - workflow: Workflow function or workflow type name string - *args: Arguments to pass to the workflow - **options_kwargs: StartWorkflowOptions as keyword arguments - - Returns: - WorkflowExecution that contains workflow_id and run_id - - Raises: - ValueError: If required parameters are missing or invalid - Exception: If the gRPC call fails - """ - return await self.start_workflow(workflow, *args, **options_kwargs) diff --git a/tests/cadence/test_client_workflow.py b/tests/cadence/test_client_workflow.py index a6087ef..ca50f99 100644 --- a/tests/cadence/test_client_workflow.py +++ b/tests/cadence/test_client_workflow.py @@ -276,48 +276,49 @@ async def mock_build_request(workflow, args, options): assert captured_options.execution_start_to_close_timeout == timedelta(minutes=30) -class TestClientExecuteWorkflow: - """Test Client.execute_workflow method.""" +class TestClientStartWorkflow: + """Test Client.start_workflow method.""" @pytest.mark.asyncio - async def test_execute_workflow_success(self, mock_client): - """Test successful workflow execution.""" - # Mock start_workflow to return execution + async def test_start_workflow_success(self, mock_client): + """Test successful workflow start.""" + # Mock the gRPC stub execution = WorkflowExecution() execution.workflow_id = "test-workflow-id" execution.run_id = "test-run-id" + response = StartWorkflowExecutionResponse() + response.run_id = "test-run-id" + client = Client(domain="test-domain", target="localhost:7933") - client.start_workflow = AsyncMock(return_value=execution) + client._workflow_stub = Mock() + client._workflow_stub.StartWorkflowExecution = AsyncMock(return_value=response) - result_execution = await client.execute_workflow( + result_execution = await client.start_workflow( "TestWorkflow", "arg1", "arg2", - task_list="test-task-list" + task_list="test-task-list", + workflow_id="test-workflow-id" ) assert isinstance(result_execution, WorkflowExecution) - assert result_execution is execution assert result_execution.workflow_id == "test-workflow-id" assert result_execution.run_id == "test-run-id" - # Verify start_workflow was called with correct arguments - client.start_workflow.assert_called_once_with( - "TestWorkflow", - "arg1", 
"arg2", - task_list="test-task-list" - ) + # Verify gRPC call was made + client._workflow_stub.StartWorkflowExecution.assert_called_once() @pytest.mark.asyncio - async def test_execute_workflow_propagates_error(self, mock_client): - """Test that execute_workflow propagates errors from start_workflow.""" + async def test_start_workflow_propagates_error(self, mock_client): + """Test that start_workflow propagates gRPC errors.""" client = Client(domain="test-domain", target="localhost:7933") - client.start_workflow = AsyncMock(side_effect=ValueError("Invalid task_list")) + client._workflow_stub = Mock() + client._workflow_stub.StartWorkflowExecution = AsyncMock(side_effect=ValueError("gRPC error")) - with pytest.raises(ValueError, match="Invalid task_list"): - await client.execute_workflow( + with pytest.raises(Exception, match="Failed to start workflow"): + await client.start_workflow( "TestWorkflow", - task_list="" + task_list="valid-task-list" ) @@ -334,7 +335,7 @@ async def test_integration_workflow_invocation(): client._workflow_stub.StartWorkflowExecution = AsyncMock(return_value=response) # Test the complete flow - execution = await client.execute_workflow( + execution = await client.start_workflow( "IntegrationTestWorkflow", "test-arg", 42, From f7df61fb581fd1f30808dbfccdbcdaaea50dabb1 Mon Sep 17 00:00:00 2001 From: Tim Li Date: Fri, 3 Oct 2025 15:11:54 -0700 Subject: [PATCH 4/8] fix linter Signed-off-by: Tim Li --- tests/cadence/test_client_workflow.py | 44 --------------------------- 1 file changed, 44 deletions(-) diff --git a/tests/cadence/test_client_workflow.py b/tests/cadence/test_client_workflow.py index ca50f99..73cc6d9 100644 --- a/tests/cadence/test_client_workflow.py +++ b/tests/cadence/test_client_workflow.py @@ -276,50 +276,6 @@ async def mock_build_request(workflow, args, options): assert captured_options.execution_start_to_close_timeout == timedelta(minutes=30) -class TestClientStartWorkflow: - """Test Client.start_workflow method.""" - - @pytest.mark.asyncio - async def test_start_workflow_success(self, mock_client): - """Test successful workflow start.""" - # Mock the gRPC stub - execution = WorkflowExecution() - execution.workflow_id = "test-workflow-id" - execution.run_id = "test-run-id" - - response = StartWorkflowExecutionResponse() - response.run_id = "test-run-id" - - client = Client(domain="test-domain", target="localhost:7933") - client._workflow_stub = Mock() - client._workflow_stub.StartWorkflowExecution = AsyncMock(return_value=response) - - result_execution = await client.start_workflow( - "TestWorkflow", - "arg1", "arg2", - task_list="test-task-list", - workflow_id="test-workflow-id" - ) - - assert isinstance(result_execution, WorkflowExecution) - assert result_execution.workflow_id == "test-workflow-id" - assert result_execution.run_id == "test-run-id" - - # Verify gRPC call was made - client._workflow_stub.StartWorkflowExecution.assert_called_once() - - @pytest.mark.asyncio - async def test_start_workflow_propagates_error(self, mock_client): - """Test that start_workflow propagates gRPC errors.""" - client = Client(domain="test-domain", target="localhost:7933") - client._workflow_stub = Mock() - client._workflow_stub.StartWorkflowExecution = AsyncMock(side_effect=ValueError("gRPC error")) - - with pytest.raises(Exception, match="Failed to start workflow"): - await client.start_workflow( - "TestWorkflow", - task_list="valid-task-list" - ) @pytest.mark.asyncio From 216a5e8554a486f37ad31e85b627eb5bdbcf6393 Mon Sep 17 00:00:00 2001 From: Tim Li Date: Tue, 
7 Oct 2025 09:44:11 -0700 Subject: [PATCH 5/8] removed unimplemented field Signed-off-by: Tim Li --- cadence/client.py | 2 -- tests/cadence/test_client_workflow.py | 8 +------- 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/cadence/client.py b/cadence/client.py index 0b52da4..f6f7253 100644 --- a/cadence/client.py +++ b/cadence/client.py @@ -33,8 +33,6 @@ class StartWorkflowOptions: task_start_to_close_timeout: Optional[timedelta] = None workflow_id_reuse_policy: int = WorkflowIdReusePolicy.WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE cron_schedule: Optional[str] = None - memo: Optional[dict[str, Any]] = None - search_attributes: Optional[dict[str, Any]] = None class ClientOptions(TypedDict, total=False): diff --git a/tests/cadence/test_client_workflow.py b/tests/cadence/test_client_workflow.py index 73cc6d9..08529fe 100644 --- a/tests/cadence/test_client_workflow.py +++ b/tests/cadence/test_client_workflow.py @@ -38,8 +38,6 @@ def test_default_values(self): assert options.task_start_to_close_timeout is None assert options.workflow_id_reuse_policy == WorkflowIdReusePolicy.WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE assert options.cron_schedule is None - assert options.memo is None - assert options.search_attributes is None def test_custom_values(self): """Test setting custom values for StartWorkflowOptions.""" @@ -49,9 +47,7 @@ def test_custom_values(self): execution_start_to_close_timeout=timedelta(minutes=30), task_start_to_close_timeout=timedelta(seconds=10), workflow_id_reuse_policy=WorkflowIdReusePolicy.WORKFLOW_ID_REUSE_POLICY_REJECT_DUPLICATE, - cron_schedule="0 * * * *", - memo={"key": "value"}, - search_attributes={"attr": "value"} + cron_schedule="0 * * * *" ) assert options.workflow_id == "custom-id" @@ -60,8 +56,6 @@ def test_custom_values(self): assert options.task_start_to_close_timeout == timedelta(seconds=10) assert options.workflow_id_reuse_policy == WorkflowIdReusePolicy.WORKFLOW_ID_REUSE_POLICY_REJECT_DUPLICATE assert options.cron_schedule == "0 * * * *" - assert options.memo == {"key": "value"} - assert options.search_attributes == {"attr": "value"} From d9ed5bbb4ad67e3a80429986a65811cd565d8303 Mon Sep 17 00:00:00 2001 From: Tim Li Date: Tue, 7 Oct 2025 11:26:50 -0700 Subject: [PATCH 6/8] respond to comments Signed-off-by: Tim Li --- cadence/client.py | 20 ++++--- tests/cadence/test_client_workflow.py | 77 ++++++++++++++++++++------- 2 files changed, 66 insertions(+), 31 deletions(-) diff --git a/cadence/client.py b/cadence/client.py index f6f7253..31b76f2 100644 --- a/cadence/client.py +++ b/cadence/client.py @@ -18,7 +18,6 @@ from cadence.api.v1.service_workflow_pb2 import StartWorkflowExecutionRequest, StartWorkflowExecutionResponse from cadence.api.v1.common_pb2 import WorkflowType, WorkflowExecution from cadence.api.v1.tasklist_pb2 import TaskList -from cadence.api.v1.workflow_pb2 import WorkflowIdReusePolicy from cadence.data_converter import DataConverter, DefaultDataConverter from cadence.metrics import MetricsEmitter, NoOpMetricsEmitter @@ -27,13 +26,19 @@ @dataclass class StartWorkflowOptions: """Options for starting a workflow execution.""" - workflow_id: Optional[str] = None - task_list: str = "" + task_list: str execution_start_to_close_timeout: Optional[timedelta] = None task_start_to_close_timeout: Optional[timedelta] = None - workflow_id_reuse_policy: int = WorkflowIdReusePolicy.WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE + workflow_id: Optional[str] = None cron_schedule: Optional[str] = None + def __post_init__(self): + """Validate required fields 
after initialization.""" + if not self.task_list: + raise ValueError("task_list is required") + if not self.execution_start_to_close_timeout and not self.task_start_to_close_timeout: + raise ValueError("either execution_start_to_close_timeout or task_start_to_close_timeout is required") + class ClientOptions(TypedDict, total=False): domain: str @@ -118,10 +123,6 @@ async def _build_start_workflow_request( # Generate workflow ID if not provided workflow_id = options.workflow_id or str(uuid.uuid4()) - # Validate required fields - if not options.task_list: - raise ValueError("task_list is required") - # Determine workflow type name if isinstance(workflow, str): workflow_type_name = workflow @@ -158,9 +159,6 @@ async def _build_start_workflow_request( request_id=str(uuid.uuid4()) ) - # Set workflow_id_reuse_policy separately to avoid type issues - request.workflow_id_reuse_policy = options.workflow_id_reuse_policy # type: ignore[assignment] - # Set optional fields if input_payload: request.input.CopyFrom(input_payload) diff --git a/tests/cadence/test_client_workflow.py b/tests/cadence/test_client_workflow.py index 08529fe..5ae2ea5 100644 --- a/tests/cadence/test_client_workflow.py +++ b/tests/cadence/test_client_workflow.py @@ -5,7 +5,6 @@ from cadence.api.v1.common_pb2 import WorkflowExecution from cadence.api.v1.service_workflow_pb2 import StartWorkflowExecutionRequest, StartWorkflowExecutionResponse -from cadence.api.v1.workflow_pb2 import WorkflowIdReusePolicy from cadence.client import Client, StartWorkflowOptions from cadence.data_converter import DefaultDataConverter @@ -31,12 +30,15 @@ class TestStartWorkflowOptions: def test_default_values(self): """Test default values for StartWorkflowOptions.""" - options = StartWorkflowOptions() + options = StartWorkflowOptions( + task_list="test-task-list", + execution_start_to_close_timeout=timedelta(minutes=10), + task_start_to_close_timeout=timedelta(seconds=30) + ) assert options.workflow_id is None - assert options.task_list == "" - assert options.execution_start_to_close_timeout is None - assert options.task_start_to_close_timeout is None - assert options.workflow_id_reuse_policy == WorkflowIdReusePolicy.WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE + assert options.task_list == "test-task-list" + assert options.execution_start_to_close_timeout == timedelta(minutes=10) + assert options.task_start_to_close_timeout == timedelta(seconds=30) assert options.cron_schedule is None def test_custom_values(self): @@ -46,7 +48,6 @@ def test_custom_values(self): task_list="test-task-list", execution_start_to_close_timeout=timedelta(minutes=30), task_start_to_close_timeout=timedelta(seconds=10), - workflow_id_reuse_policy=WorkflowIdReusePolicy.WORKFLOW_ID_REUSE_POLICY_REJECT_DUPLICATE, cron_schedule="0 * * * *" ) @@ -54,7 +55,6 @@ def test_custom_values(self): assert options.task_list == "test-task-list" assert options.execution_start_to_close_timeout == timedelta(minutes=30) assert options.task_start_to_close_timeout == timedelta(seconds=10) - assert options.workflow_id_reuse_policy == WorkflowIdReusePolicy.WORKFLOW_ID_REUSE_POLICY_REJECT_DUPLICATE assert options.cron_schedule == "0 * * * *" @@ -84,7 +84,7 @@ async def test_build_request_with_string_workflow(self, mock_client): assert request.workflow_type.name == "TestWorkflow" assert request.task_list.name == "test-task-list" assert request.identity == client.identity - assert request.workflow_id_reuse_policy == WorkflowIdReusePolicy.WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE + assert 
request.workflow_id_reuse_policy == 0 # Default protobuf value when not set assert request.request_id != "" # Should be a UUID # Verify UUID format @@ -99,7 +99,9 @@ def test_workflow(): client = Client(domain="test-domain", target="localhost:7933") options = StartWorkflowOptions( - task_list="test-task-list" + task_list="test-task-list", + execution_start_to_close_timeout=timedelta(minutes=10), + task_start_to_close_timeout=timedelta(seconds=30) ) request = await client._build_start_workflow_request(test_workflow, (), options) @@ -111,7 +113,11 @@ async def test_build_request_generates_workflow_id(self, mock_client): """Test that workflow_id is generated when not provided.""" client = Client(domain="test-domain", target="localhost:7933") - options = StartWorkflowOptions(task_list="test-task-list") + options = StartWorkflowOptions( + task_list="test-task-list", + execution_start_to_close_timeout=timedelta(minutes=10), + task_start_to_close_timeout=timedelta(seconds=30) + ) request = await client._build_start_workflow_request("TestWorkflow", (), options) @@ -122,19 +128,42 @@ async def test_build_request_generates_workflow_id(self, mock_client): @pytest.mark.asyncio async def test_build_request_missing_task_list(self, mock_client): """Test that missing task_list raises ValueError.""" - client = Client(domain="test-domain", target="localhost:7933") + with pytest.raises(TypeError): # task_list is now required positional argument + StartWorkflowOptions() # No task_list - options = StartWorkflowOptions() # No task_list + def test_missing_timeout_raises_error(self): + """Test that missing both timeouts raises ValueError.""" + with pytest.raises(ValueError, match="either execution_start_to_close_timeout or task_start_to_close_timeout is required"): + StartWorkflowOptions(task_list="test-task-list") - with pytest.raises(ValueError, match="task_list is required"): - await client._build_start_workflow_request("TestWorkflow", (), options) + def test_only_execution_timeout(self): + """Test that only execution_start_to_close_timeout is valid.""" + options = StartWorkflowOptions( + task_list="test-task-list", + execution_start_to_close_timeout=timedelta(minutes=10) + ) + assert options.execution_start_to_close_timeout == timedelta(minutes=10) + assert options.task_start_to_close_timeout is None + + def test_only_task_timeout(self): + """Test that only task_start_to_close_timeout is valid.""" + options = StartWorkflowOptions( + task_list="test-task-list", + task_start_to_close_timeout=timedelta(seconds=30) + ) + assert options.execution_start_to_close_timeout is None + assert options.task_start_to_close_timeout == timedelta(seconds=30) @pytest.mark.asyncio async def test_build_request_with_input_args(self, mock_client): """Test building request with input arguments.""" client = Client(domain="test-domain", target="localhost:7933") - options = StartWorkflowOptions(task_list="test-task-list") + options = StartWorkflowOptions( + task_list="test-task-list", + execution_start_to_close_timeout=timedelta(minutes=10), + task_start_to_close_timeout=timedelta(seconds=30) + ) request = await client._build_start_workflow_request("TestWorkflow", ("arg1", 42, {"key": "value"}), options) @@ -169,6 +198,8 @@ async def test_build_request_with_cron_schedule(self, mock_client): options = StartWorkflowOptions( task_list="test-task-list", + execution_start_to_close_timeout=timedelta(minutes=10), + task_start_to_close_timeout=timedelta(seconds=30), cron_schedule="0 * * * *" ) @@ -206,7 +237,9 @@ async def 
mock_build_request(workflow, args, options): "TestWorkflow", "arg1", "arg2", task_list="test-task-list", - workflow_id="test-workflow-id" + workflow_id="test-workflow-id", + execution_start_to_close_timeout=timedelta(minutes=10), + task_start_to_close_timeout=timedelta(seconds=30) ) assert isinstance(execution, WorkflowExecution) @@ -231,7 +264,9 @@ async def test_start_workflow_grpc_error(self, mock_client): with pytest.raises(Exception, match="Failed to start workflow: gRPC error"): await client.start_workflow( "TestWorkflow", - task_list="test-task-list" + task_list="test-task-list", + execution_start_to_close_timeout=timedelta(minutes=10), + task_start_to_close_timeout=timedelta(seconds=30) ) @pytest.mark.asyncio @@ -261,7 +296,8 @@ async def mock_build_request(workflow, args, options): "arg1", task_list="test-task-list", workflow_id="custom-id", - execution_start_to_close_timeout=timedelta(minutes=30) + execution_start_to_close_timeout=timedelta(minutes=30), + task_start_to_close_timeout=timedelta(seconds=30) ) # Verify options were properly constructed @@ -292,7 +328,8 @@ async def test_integration_workflow_invocation(): {"data": "value"}, task_list="integration-task-list", workflow_id="integration-workflow-id", - execution_start_to_close_timeout=timedelta(minutes=10) + execution_start_to_close_timeout=timedelta(minutes=10), + task_start_to_close_timeout=timedelta(seconds=30) ) # Verify result From 0c82fb0d8d3ba9327ffbc0fad3050cefa1f5cb67 Mon Sep 17 00:00:00 2001 From: Tim Li Date: Tue, 7 Oct 2025 15:58:29 -0700 Subject: [PATCH 7/8] respond to comment Signed-off-by: Tim Li --- cadence/client.py | 117 +++++++++------ tests/cadence/test_client_workflow.py | 198 ++++++++++++++++++-------- 2 files changed, 210 insertions(+), 105 deletions(-) diff --git a/cadence/client.py b/cadence/client.py index 31b76f2..e36ffd6 100644 --- a/cadence/client.py +++ b/cadence/client.py @@ -15,29 +15,45 @@ from cadence.api.v1.service_worker_pb2_grpc import WorkerAPIStub from grpc.aio import Channel, ClientInterceptor, secure_channel, insecure_channel from cadence.api.v1.service_workflow_pb2_grpc import WorkflowAPIStub -from cadence.api.v1.service_workflow_pb2 import StartWorkflowExecutionRequest, StartWorkflowExecutionResponse +from cadence.api.v1.service_workflow_pb2 import ( + StartWorkflowExecutionRequest, + StartWorkflowExecutionResponse, +) from cadence.api.v1.common_pb2 import WorkflowType, WorkflowExecution from cadence.api.v1.tasklist_pb2 import TaskList from cadence.data_converter import DataConverter, DefaultDataConverter from cadence.metrics import MetricsEmitter, NoOpMetricsEmitter - -@dataclass -class StartWorkflowOptions: +class StartWorkflowOptions(TypedDict, total=False): """Options for starting a workflow execution.""" + task_list: str - execution_start_to_close_timeout: Optional[timedelta] = None - task_start_to_close_timeout: Optional[timedelta] = None - workflow_id: Optional[str] = None - cron_schedule: Optional[str] = None + execution_start_to_close_timeout: timedelta + workflow_id: str + task_start_to_close_timeout: timedelta + cron_schedule: str + + +def _validate_and_apply_defaults(options: StartWorkflowOptions) -> StartWorkflowOptions: + """Validate required fields and apply defaults to StartWorkflowOptions.""" + if not options.get("task_list"): + raise ValueError("task_list is required") + + execution_timeout = options.get("execution_start_to_close_timeout") + if not execution_timeout: + raise ValueError("execution_start_to_close_timeout is required") + if execution_timeout <= 
timedelta(0): + raise ValueError("execution_start_to_close_timeout must be greater than 0") + + # Apply default for task_start_to_close_timeout if not provided (matching Go/Java clients) + task_timeout = options.get("task_start_to_close_timeout") + if task_timeout is None: + options["task_start_to_close_timeout"] = timedelta(seconds=10) + elif task_timeout <= timedelta(0): + raise ValueError("task_start_to_close_timeout must be greater than 0") - def __post_init__(self): - """Validate required fields after initialization.""" - if not self.task_list: - raise ValueError("task_list is required") - if not self.execution_start_to_close_timeout and not self.task_start_to_close_timeout: - raise ValueError("either execution_start_to_close_timeout or task_start_to_close_timeout is required") + return options class ClientOptions(TypedDict, total=False): @@ -53,6 +69,7 @@ class ClientOptions(TypedDict, total=False): metrics_emitter: MetricsEmitter interceptors: list[ClientInterceptor] + _DEFAULT_OPTIONS: ClientOptions = { "data_converter": DefaultDataConverter(), "identity": f"{os.getpid()}@{socket.gethostname()}", @@ -65,6 +82,7 @@ class ClientOptions(TypedDict, total=False): "interceptors": [], } + class Client: def __init__(self, **kwargs: Unpack[ClientOptions]) -> None: self._options = _validate_and_copy_defaults(ClientOptions(**kwargs)) @@ -107,7 +125,7 @@ async def ready(self) -> None: async def close(self) -> None: await self._channel.close() - async def __aenter__(self) -> 'Client': + async def __aenter__(self) -> "Client": return self async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: @@ -117,18 +135,18 @@ async def _build_start_workflow_request( self, workflow: Union[str, Callable], args: tuple[Any, ...], - options: StartWorkflowOptions + options: StartWorkflowOptions, ) -> StartWorkflowExecutionRequest: """Build a StartWorkflowExecutionRequest from parameters.""" # Generate workflow ID if not provided - workflow_id = options.workflow_id or str(uuid.uuid4()) + workflow_id = options.get("workflow_id") or str(uuid.uuid4()) # Determine workflow type name if isinstance(workflow, str): workflow_type_name = workflow else: # For callable, use function name or __name__ attribute - workflow_type_name = getattr(workflow, '__name__', str(workflow)) + workflow_type_name = getattr(workflow, "__name__", str(workflow)) # Encode input arguments input_payload = None @@ -139,35 +157,31 @@ async def _build_start_workflow_request( raise ValueError(f"Failed to encode workflow arguments: {e}") # Convert timedelta to protobuf Duration - execution_timeout = None - if options.execution_start_to_close_timeout: - execution_timeout = Duration() - execution_timeout.FromTimedelta(options.execution_start_to_close_timeout) + execution_timeout = Duration() + execution_timeout.FromTimedelta(options["execution_start_to_close_timeout"]) - task_timeout = None - if options.task_start_to_close_timeout: - task_timeout = Duration() - task_timeout.FromTimedelta(options.task_start_to_close_timeout) + task_timeout = Duration() + task_timeout.FromTimedelta(options["task_start_to_close_timeout"]) # Build the request request = StartWorkflowExecutionRequest( domain=self.domain, workflow_id=workflow_id, workflow_type=WorkflowType(name=workflow_type_name), - task_list=TaskList(name=options.task_list), + task_list=TaskList(name=options["task_list"]), identity=self.identity, - request_id=str(uuid.uuid4()) + request_id=str(uuid.uuid4()), ) + # Set required timeout fields + 
request.execution_start_to_close_timeout.CopyFrom(execution_timeout) + request.task_start_to_close_timeout.CopyFrom(task_timeout) + # Set optional fields if input_payload: request.input.CopyFrom(input_payload) - if execution_timeout: - request.execution_start_to_close_timeout.CopyFrom(execution_timeout) - if task_timeout: - request.task_start_to_close_timeout.CopyFrom(task_timeout) - if options.cron_schedule: - request.cron_schedule = options.cron_schedule + if options.get("cron_schedule"): + request.cron_schedule = options["cron_schedule"] return request @@ -175,7 +189,7 @@ async def start_workflow( self, workflow: Union[str, Callable], *args, - **options_kwargs + **options_kwargs: Unpack[StartWorkflowOptions], ) -> WorkflowExecution: """ Start a workflow execution asynchronously. @@ -192,15 +206,17 @@ async def start_workflow( ValueError: If required parameters are missing or invalid Exception: If the gRPC call fails """ - # Convert kwargs to StartWorkflowOptions - options = StartWorkflowOptions(**options_kwargs) + # Convert kwargs to StartWorkflowOptions and validate + options = _validate_and_apply_defaults(StartWorkflowOptions(options_kwargs)) # Build the gRPC request request = await self._build_start_workflow_request(workflow, args, options) # Execute the gRPC call try: - response: StartWorkflowExecutionResponse = await self.workflow_stub.StartWorkflowExecution(request) + response: StartWorkflowExecutionResponse = ( + await self.workflow_stub.StartWorkflowExecution(request) + ) # Emit metrics if available if self.metrics_emitter: @@ -211,10 +227,8 @@ async def start_workflow( execution.workflow_id = request.workflow_id execution.run_id = response.run_id return execution - except Exception as e: - raise Exception(f"Failed to start workflow: {e}") from e - - + except Exception: + raise def _validate_and_copy_defaults(options: ClientOptions) -> ClientOptions: @@ -234,11 +248,24 @@ def _validate_and_copy_defaults(options: ClientOptions) -> ClientOptions: def _create_channel(options: ClientOptions) -> Channel: interceptors = list(options["interceptors"]) - interceptors.append(YarpcMetadataInterceptor(options["service_name"], options["caller_name"])) + interceptors.append( + YarpcMetadataInterceptor(options["service_name"], options["caller_name"]) + ) interceptors.append(RetryInterceptor()) interceptors.append(CadenceErrorInterceptor()) if options["credentials"]: - return secure_channel(options["target"], options["credentials"], options["channel_arguments"], options["compression"], interceptors) + return secure_channel( + options["target"], + options["credentials"], + options["channel_arguments"], + options["compression"], + interceptors, + ) else: - return insecure_channel(options["target"], options["channel_arguments"], options["compression"], interceptors) + return insecure_channel( + options["target"], + options["channel_arguments"], + options["compression"], + interceptors, + ) diff --git a/tests/cadence/test_client_workflow.py b/tests/cadence/test_client_workflow.py index 5ae2ea5..9f6a7c3 100644 --- a/tests/cadence/test_client_workflow.py +++ b/tests/cadence/test_client_workflow.py @@ -4,8 +4,11 @@ from unittest.mock import AsyncMock, Mock, PropertyMock from cadence.api.v1.common_pb2 import WorkflowExecution -from cadence.api.v1.service_workflow_pb2 import StartWorkflowExecutionRequest, StartWorkflowExecutionResponse -from cadence.client import Client, StartWorkflowOptions +from cadence.api.v1.service_workflow_pb2 import ( + StartWorkflowExecutionRequest, + 
+    StartWorkflowExecutionResponse,
+)
+from cadence.client import Client, StartWorkflowOptions, _validate_and_apply_defaults
 from cadence.data_converter import DefaultDataConverter
 
 
@@ -26,20 +29,20 @@ def mock_client():
 
 
 class TestStartWorkflowOptions:
-    """Test StartWorkflowOptions dataclass."""
+    """Test StartWorkflowOptions TypedDict and validation."""
 
     def test_default_values(self):
         """Test default values for StartWorkflowOptions."""
         options = StartWorkflowOptions(
             task_list="test-task-list",
             execution_start_to_close_timeout=timedelta(minutes=10),
-            task_start_to_close_timeout=timedelta(seconds=30)
+            task_start_to_close_timeout=timedelta(seconds=30),
         )
-        assert options.workflow_id is None
-        assert options.task_list == "test-task-list"
-        assert options.execution_start_to_close_timeout == timedelta(minutes=10)
-        assert options.task_start_to_close_timeout == timedelta(seconds=30)
-        assert options.cron_schedule is None
+        assert options.get("workflow_id") is None
+        assert options["task_list"] == "test-task-list"
+        assert options["execution_start_to_close_timeout"] == timedelta(minutes=10)
+        assert options["task_start_to_close_timeout"] == timedelta(seconds=30)
+        assert options.get("cron_schedule") is None
 
     def test_custom_values(self):
         """Test setting custom values for StartWorkflowOptions."""
@@ -48,16 +51,14 @@ def test_custom_values(self):
             task_list="test-task-list",
             execution_start_to_close_timeout=timedelta(minutes=30),
             task_start_to_close_timeout=timedelta(seconds=10),
-            cron_schedule="0 * * * *"
+            cron_schedule="0 * * * *",
         )
-        assert options.workflow_id == "custom-id"
-        assert options.task_list == "test-task-list"
-        assert options.execution_start_to_close_timeout == timedelta(minutes=30)
-        assert options.task_start_to_close_timeout == timedelta(seconds=10)
-        assert options.cron_schedule == "0 * * * *"
-
-
+        assert options["workflow_id"] == "custom-id"
+        assert options["task_list"] == "test-task-list"
+        assert options["execution_start_to_close_timeout"] == timedelta(minutes=30)
+        assert options["task_start_to_close_timeout"] == timedelta(seconds=10)
+        assert options["cron_schedule"] == "0 * * * *"
 
 
 class TestClientBuildStartWorkflowRequest:
@@ -73,10 +74,12 @@ async def test_build_request_with_string_workflow(self, mock_client):
             workflow_id="test-workflow-id",
             task_list="test-task-list",
             execution_start_to_close_timeout=timedelta(minutes=30),
-            task_start_to_close_timeout=timedelta(seconds=10)
+            task_start_to_close_timeout=timedelta(seconds=10),
         )
 
-        request = await client._build_start_workflow_request("TestWorkflow", ("arg1", "arg2"), options)
+        request = await client._build_start_workflow_request(
+            "TestWorkflow", ("arg1", "arg2"), options
+        )
 
         assert isinstance(request, StartWorkflowExecutionRequest)
         assert request.domain == "test-domain"
@@ -84,7 +87,9 @@ async def test_build_request_with_string_workflow(self, mock_client):
         assert request.workflow_type.name == "TestWorkflow"
         assert request.task_list.name == "test-task-list"
         assert request.identity == client.identity
-        assert request.workflow_id_reuse_policy == 0  # Default protobuf value when not set
+        assert (
+            request.workflow_id_reuse_policy == 0
+        )  # Default protobuf value when not set
        assert request.request_id != ""  # Should be a UUID
 
         # Verify UUID format
@@ -93,6 +98,7 @@
     @pytest.mark.asyncio
     async def test_build_request_with_callable_workflow(self, mock_client):
         """Test building request with callable workflow."""
+
         def test_workflow():
             pass
 
@@ -101,7 +107,7 @@ def test_workflow():
         options = StartWorkflowOptions(
             task_list="test-task-list",
             execution_start_to_close_timeout=timedelta(minutes=10),
-            task_start_to_close_timeout=timedelta(seconds=30)
+            task_start_to_close_timeout=timedelta(seconds=30),
         )
 
         request = await client._build_start_workflow_request(test_workflow, (), options)
@@ -116,43 +122,64 @@ async def test_build_request_generates_workflow_id(self, mock_client):
         options = StartWorkflowOptions(
             task_list="test-task-list",
             execution_start_to_close_timeout=timedelta(minutes=10),
-            task_start_to_close_timeout=timedelta(seconds=30)
+            task_start_to_close_timeout=timedelta(seconds=30),
         )
 
-        request = await client._build_start_workflow_request("TestWorkflow", (), options)
+        request = await client._build_start_workflow_request(
+            "TestWorkflow", (), options
+        )
 
         assert request.workflow_id != ""
         # Verify it's a valid UUID
         uuid.UUID(request.workflow_id)
 
-    @pytest.mark.asyncio
-    async def test_build_request_missing_task_list(self, mock_client):
+    def test_missing_task_list_raises_error(self):
         """Test that missing task_list raises ValueError."""
-        with pytest.raises(TypeError):  # task_list is now required positional argument
-            StartWorkflowOptions()  # No task_list
-
-    def test_missing_timeout_raises_error(self):
-        """Test that missing both timeouts raises ValueError."""
-        with pytest.raises(ValueError, match="either execution_start_to_close_timeout or task_start_to_close_timeout is required"):
-            StartWorkflowOptions(task_list="test-task-list")
+        options = StartWorkflowOptions()
+        with pytest.raises(ValueError, match="task_list is required"):
+            _validate_and_apply_defaults(options)
+
+    def test_missing_execution_timeout_raises_error(self):
+        """Test that missing execution_start_to_close_timeout raises ValueError."""
+        options = StartWorkflowOptions(task_list="test-task-list")
+        with pytest.raises(
+            ValueError, match="execution_start_to_close_timeout is required"
+        ):
+            _validate_and_apply_defaults(options)
 
     def test_only_execution_timeout(self):
-        """Test that only execution_start_to_close_timeout is valid."""
+        """Test that only execution_start_to_close_timeout works with default task timeout."""
         options = StartWorkflowOptions(
             task_list="test-task-list",
-            execution_start_to_close_timeout=timedelta(minutes=10)
+            execution_start_to_close_timeout=timedelta(minutes=10),
         )
-        assert options.execution_start_to_close_timeout == timedelta(minutes=10)
-        assert options.task_start_to_close_timeout is None
+        validated_options = _validate_and_apply_defaults(options)
+        assert validated_options["execution_start_to_close_timeout"] == timedelta(
+            minutes=10
+        )
+        assert validated_options["task_start_to_close_timeout"] == timedelta(
+            seconds=10
+        )  # Default applied
+
+    def test_default_task_timeout(self):
+        """Test that task_start_to_close_timeout defaults to 10 seconds when not provided."""
+        options = StartWorkflowOptions(
+            task_list="test-task-list",
+            execution_start_to_close_timeout=timedelta(minutes=5),
+        )
+        validated_options = _validate_and_apply_defaults(options)
+        assert validated_options["task_start_to_close_timeout"] == timedelta(seconds=10)
 
     def test_only_task_timeout(self):
-        """Test that only task_start_to_close_timeout is valid."""
+        """Test that only task_start_to_close_timeout raises ValueError (execution timeout required)."""
         options = StartWorkflowOptions(
             task_list="test-task-list",
-            task_start_to_close_timeout=timedelta(seconds=30)
+            task_start_to_close_timeout=timedelta(seconds=30),
         )
-        assert options.execution_start_to_close_timeout is None
-        assert options.task_start_to_close_timeout == timedelta(seconds=30)
+        with pytest.raises(
+            ValueError, match="execution_start_to_close_timeout is required"
+        ):
+            _validate_and_apply_defaults(options)
 
     @pytest.mark.asyncio
     async def test_build_request_with_input_args(self, mock_client):
@@ -162,10 +189,12 @@ async def test_build_request_with_input_args(self, mock_client):
         options = StartWorkflowOptions(
             task_list="test-task-list",
             execution_start_to_close_timeout=timedelta(minutes=10),
-            task_start_to_close_timeout=timedelta(seconds=30)
+            task_start_to_close_timeout=timedelta(seconds=30),
         )
 
-        request = await client._build_start_workflow_request("TestWorkflow", ("arg1", 42, {"key": "value"}), options)
+        request = await client._build_start_workflow_request(
+            "TestWorkflow", ("arg1", 42, {"key": "value"}), options
+        )
 
         # Should have input payload
         assert request.HasField("input")
@@ -179,10 +208,12 @@ async def test_build_request_with_timeouts(self, mock_client):
         options = StartWorkflowOptions(
             task_list="test-task-list",
             execution_start_to_close_timeout=timedelta(minutes=30),
-            task_start_to_close_timeout=timedelta(seconds=10)
+            task_start_to_close_timeout=timedelta(seconds=10),
         )
 
-        request = await client._build_start_workflow_request("TestWorkflow", (), options)
+        request = await client._build_start_workflow_request(
+            "TestWorkflow", (), options
+        )
 
         assert request.HasField("execution_start_to_close_timeout")
         assert request.HasField("task_start_to_close_timeout")
@@ -200,10 +231,12 @@ async def test_build_request_with_cron_schedule(self, mock_client):
             task_list="test-task-list",
             execution_start_to_close_timeout=timedelta(minutes=10),
             task_start_to_close_timeout=timedelta(seconds=30),
-            cron_schedule="0 * * * *"
+            cron_schedule="0 * * * *",
         )
 
-        request = await client._build_start_workflow_request("TestWorkflow", (), options)
+        request = await client._build_start_workflow_request(
+            "TestWorkflow", (), options
+        )
 
         assert request.cron_schedule == "0 * * * *"
 
@@ -218,7 +251,9 @@ async def test_start_workflow_success(self, mock_client):
         response = StartWorkflowExecutionResponse()
         response.run_id = "test-run-id"
 
-        mock_client.workflow_stub.StartWorkflowExecution = AsyncMock(return_value=response)
+        mock_client.workflow_stub.StartWorkflowExecution = AsyncMock(
+            return_value=response
+        )
 
         # Create a real client but replace the workflow_stub
         client = Client(domain="test-domain", target="localhost:7933")
@@ -235,11 +270,12 @@ async def mock_build_request(workflow, args, options):
 
         execution = await client.start_workflow(
             "TestWorkflow",
-            "arg1", "arg2",
+            "arg1",
+            "arg2",
             task_list="test-task-list",
             workflow_id="test-workflow-id",
             execution_start_to_close_timeout=timedelta(minutes=10),
-            task_start_to_close_timeout=timedelta(seconds=30)
+            task_start_to_close_timeout=timedelta(seconds=30),
         )
 
         assert isinstance(execution, WorkflowExecution)
@@ -253,20 +289,24 @@ async def mock_build_request(workflow, args, options):
     async def test_start_workflow_grpc_error(self, mock_client):
         """Test workflow start with gRPC error."""
         # Setup mock to raise exception
-        mock_client.workflow_stub.StartWorkflowExecution = AsyncMock(side_effect=Exception("gRPC error"))
+        mock_client.workflow_stub.StartWorkflowExecution = AsyncMock(
+            side_effect=Exception("gRPC error")
+        )
 
         client = Client(domain="test-domain", target="localhost:7933")
         client._workflow_stub = mock_client.workflow_stub
 
         # Mock the internal method
-        client._build_start_workflow_request = AsyncMock(return_value=StartWorkflowExecutionRequest())
+        client._build_start_workflow_request = AsyncMock(
+            return_value=StartWorkflowExecutionRequest()
+        )
 
-        with pytest.raises(Exception, match="Failed to start workflow: gRPC error"):
+        with pytest.raises(Exception, match="gRPC error"):
             await client.start_workflow(
                 "TestWorkflow",
                 task_list="test-task-list",
                 execution_start_to_close_timeout=timedelta(minutes=10),
-                task_start_to_close_timeout=timedelta(seconds=30)
+                task_start_to_close_timeout=timedelta(seconds=30),
             )
 
     @pytest.mark.asyncio
@@ -275,13 +315,16 @@ async def test_start_workflow_with_kwargs(self, mock_client):
         response = StartWorkflowExecutionResponse()
         response.run_id = "test-run-id"
 
-        mock_client.workflow_stub.StartWorkflowExecution = AsyncMock(return_value=response)
+        mock_client.workflow_stub.StartWorkflowExecution = AsyncMock(
+            return_value=response
+        )
 
         client = Client(domain="test-domain", target="localhost:7933")
         client._workflow_stub = mock_client.workflow_stub
 
         # Mock the internal method to capture options
         captured_options = None
+
         async def mock_build_request(workflow, args, options):
             nonlocal captured_options
             captured_options = options
@@ -297,15 +340,50 @@ async def mock_build_request(workflow, args, options):
             task_list="test-task-list",
             workflow_id="custom-id",
             execution_start_to_close_timeout=timedelta(minutes=30),
-            task_start_to_close_timeout=timedelta(seconds=30)
+            task_start_to_close_timeout=timedelta(seconds=30),
         )
 
         # Verify options were properly constructed
-        assert captured_options.task_list == "test-task-list"
-        assert captured_options.workflow_id == "custom-id"
-        assert captured_options.execution_start_to_close_timeout == timedelta(minutes=30)
+        assert captured_options["task_list"] == "test-task-list"
+        assert captured_options["workflow_id"] == "custom-id"
+        assert captured_options["execution_start_to_close_timeout"] == timedelta(
+            minutes=30
+        )
+
+    @pytest.mark.asyncio
+    async def test_start_workflow_with_default_task_timeout(self, mock_client):
+        """Test start_workflow uses default task timeout when not provided."""
+        response = StartWorkflowExecutionResponse()
+        response.run_id = "test-run-id"
+
+        mock_client.workflow_stub.StartWorkflowExecution = AsyncMock(
+            return_value=response
+        )
+
+        client = Client(domain="test-domain", target="localhost:7933")
+        client._workflow_stub = mock_client.workflow_stub
+
+        # Mock the internal method to capture options
+        captured_options = None
+        async def mock_build_request(workflow, args, options):
+            nonlocal captured_options
+            captured_options = options
+            request = StartWorkflowExecutionRequest()
+            request.workflow_id = "test-workflow-id"
+            return request
+
+        client._build_start_workflow_request = AsyncMock(side_effect=mock_build_request)
+
+        await client.start_workflow(
+            "TestWorkflow",
+            task_list="test-task-list",
+            execution_start_to_close_timeout=timedelta(minutes=10),
+            # No task_start_to_close_timeout provided - should use default
+        )
+
+        # Verify default was applied
+        assert captured_options["task_start_to_close_timeout"] == timedelta(seconds=10)
 
     @pytest.mark.asyncio
@@ -329,7 +407,7 @@ async def test_integration_workflow_invocation():
         task_list="integration-task-list",
         workflow_id="integration-workflow-id",
         execution_start_to_close_timeout=timedelta(minutes=10),
-        task_start_to_close_timeout=timedelta(seconds=30)
+        task_start_to_close_timeout=timedelta(seconds=30),
     )
 
     # Verify result
@@ -345,4 +423,4 @@ async def test_integration_workflow_invocation():
     assert request.workflow_type.name == "IntegrationTestWorkflow"
     assert request.task_list.name == "integration-task-list"
     assert request.HasField("input")  # Should have encoded input
-    assert request.HasField("execution_start_to_close_timeout")
\ No newline at end of file
+    assert request.HasField("execution_start_to_close_timeout")

From ce1f8f20210e49ca99e337592e8e963742b7cf90 Mon Sep 17 00:00:00 2001
From: Tim Li
Date: Tue, 7 Oct 2025 16:01:15 -0700
Subject: [PATCH 8/8] fix linter

Signed-off-by: Tim Li
---
 cadence/client.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/cadence/client.py b/cadence/client.py
index e36ffd6..dcad7f3 100644
--- a/cadence/client.py
+++ b/cadence/client.py
@@ -1,9 +1,8 @@
 import os
 import socket
 import uuid
-from dataclasses import dataclass
 from datetime import timedelta
-from typing import TypedDict, Unpack, Any, cast, Union, Optional, Callable
+from typing import TypedDict, Unpack, Any, cast, Union, Callable
 from grpc import ChannelCredentials, Compression
 from google.protobuf.duration_pb2 import Duration
 
@@ -207,7 +206,7 @@ async def start_workflow(
             Exception: If the gRPC call fails
         """
         # Convert kwargs to StartWorkflowOptions and validate
-        options = _validate_and_apply_defaults(StartWorkflowOptions(options_kwargs))
+        options = _validate_and_apply_defaults(StartWorkflowOptions(**options_kwargs))
 
         # Build the gRPC request
         request = await self._build_start_workflow_request(workflow, args, options)