diff --git a/cadence/client.py b/cadence/client.py index 77ec95c..dcad7f3 100644 --- a/cadence/client.py +++ b/cadence/client.py @@ -1,8 +1,11 @@ import os import socket -from typing import TypedDict, Unpack, Any, cast +import uuid +from datetime import timedelta +from typing import TypedDict, Unpack, Any, cast, Union, Callable from grpc import ChannelCredentials, Compression +from google.protobuf.duration_pb2 import Duration from cadence._internal.rpc.error import CadenceErrorInterceptor from cadence._internal.rpc.retry import RetryInterceptor @@ -11,10 +14,47 @@ from cadence.api.v1.service_worker_pb2_grpc import WorkerAPIStub from grpc.aio import Channel, ClientInterceptor, secure_channel, insecure_channel from cadence.api.v1.service_workflow_pb2_grpc import WorkflowAPIStub +from cadence.api.v1.service_workflow_pb2 import ( + StartWorkflowExecutionRequest, + StartWorkflowExecutionResponse, +) +from cadence.api.v1.common_pb2 import WorkflowType, WorkflowExecution +from cadence.api.v1.tasklist_pb2 import TaskList from cadence.data_converter import DataConverter, DefaultDataConverter from cadence.metrics import MetricsEmitter, NoOpMetricsEmitter +class StartWorkflowOptions(TypedDict, total=False): + """Options for starting a workflow execution.""" + + task_list: str + execution_start_to_close_timeout: timedelta + workflow_id: str + task_start_to_close_timeout: timedelta + cron_schedule: str + + +def _validate_and_apply_defaults(options: StartWorkflowOptions) -> StartWorkflowOptions: + """Validate required fields and apply defaults to StartWorkflowOptions.""" + if not options.get("task_list"): + raise ValueError("task_list is required") + + execution_timeout = options.get("execution_start_to_close_timeout") + if not execution_timeout: + raise ValueError("execution_start_to_close_timeout is required") + if execution_timeout <= timedelta(0): + raise ValueError("execution_start_to_close_timeout must be greater than 0") + + # Apply default for task_start_to_close_timeout if not provided (matching Go/Java clients) + task_timeout = options.get("task_start_to_close_timeout") + if task_timeout is None: + options["task_start_to_close_timeout"] = timedelta(seconds=10) + elif task_timeout <= timedelta(0): + raise ValueError("task_start_to_close_timeout must be greater than 0") + + return options + + class ClientOptions(TypedDict, total=False): domain: str target: str @@ -28,6 +68,7 @@ class ClientOptions(TypedDict, total=False): metrics_emitter: MetricsEmitter interceptors: list[ClientInterceptor] + _DEFAULT_OPTIONS: ClientOptions = { "data_converter": DefaultDataConverter(), "identity": f"{os.getpid()}@{socket.gethostname()}", @@ -40,6 +81,7 @@ class ClientOptions(TypedDict, total=False): "interceptors": [], } + class Client: def __init__(self, **kwargs: Unpack[ClientOptions]) -> None: self._options = _validate_and_copy_defaults(ClientOptions(**kwargs)) @@ -82,12 +124,112 @@ async def ready(self) -> None: async def close(self) -> None: await self._channel.close() - async def __aenter__(self) -> 'Client': + async def __aenter__(self) -> "Client": return self async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: await self.close() + async def _build_start_workflow_request( + self, + workflow: Union[str, Callable], + args: tuple[Any, ...], + options: StartWorkflowOptions, + ) -> StartWorkflowExecutionRequest: + """Build a StartWorkflowExecutionRequest from parameters.""" + # Generate workflow ID if not provided + workflow_id = options.get("workflow_id") or str(uuid.uuid4()) + + # Determine workflow type name + if isinstance(workflow, str): + workflow_type_name = workflow + else: + # For callable, use function name or __name__ attribute + workflow_type_name = getattr(workflow, "__name__", str(workflow)) + + # Encode input arguments + input_payload = None + if args: + try: + input_payload = await self.data_converter.to_data(list(args)) + except Exception as e: + raise ValueError(f"Failed to encode workflow arguments: {e}") + + # Convert timedelta to protobuf Duration + execution_timeout = Duration() + execution_timeout.FromTimedelta(options["execution_start_to_close_timeout"]) + + task_timeout = Duration() + task_timeout.FromTimedelta(options["task_start_to_close_timeout"]) + + # Build the request + request = StartWorkflowExecutionRequest( + domain=self.domain, + workflow_id=workflow_id, + workflow_type=WorkflowType(name=workflow_type_name), + task_list=TaskList(name=options["task_list"]), + identity=self.identity, + request_id=str(uuid.uuid4()), + ) + + # Set required timeout fields + request.execution_start_to_close_timeout.CopyFrom(execution_timeout) + request.task_start_to_close_timeout.CopyFrom(task_timeout) + + # Set optional fields + if input_payload: + request.input.CopyFrom(input_payload) + if options.get("cron_schedule"): + request.cron_schedule = options["cron_schedule"] + + return request + + async def start_workflow( + self, + workflow: Union[str, Callable], + *args, + **options_kwargs: Unpack[StartWorkflowOptions], + ) -> WorkflowExecution: + """ + Start a workflow execution asynchronously. + + Args: + workflow: Workflow function or workflow type name string + *args: Arguments to pass to the workflow + **options_kwargs: StartWorkflowOptions as keyword arguments + + Returns: + WorkflowExecution with workflow_id and run_id + + Raises: + ValueError: If required parameters are missing or invalid + Exception: If the gRPC call fails + """ + # Convert kwargs to StartWorkflowOptions and validate + options = _validate_and_apply_defaults(StartWorkflowOptions(**options_kwargs)) + + # Build the gRPC request + request = await self._build_start_workflow_request(workflow, args, options) + + # Execute the gRPC call + try: + response: StartWorkflowExecutionResponse = ( + await self.workflow_stub.StartWorkflowExecution(request) + ) + + # Emit metrics if available + if self.metrics_emitter: + # TODO: Add workflow start metrics similar to Go client + pass + + execution = WorkflowExecution() + execution.workflow_id = request.workflow_id + execution.run_id = response.run_id + return execution + except Exception: + raise + + def _validate_and_copy_defaults(options: ClientOptions) -> ClientOptions: if "target" not in options: raise ValueError("target must be specified") @@ -105,11 +247,24 @@ def _validate_and_copy_defaults(options: ClientOptions) -> ClientOptions: def _create_channel(options: ClientOptions) -> Channel: interceptors = list(options["interceptors"]) - interceptors.append(YarpcMetadataInterceptor(options["service_name"], options["caller_name"])) + interceptors.append( + YarpcMetadataInterceptor(options["service_name"], options["caller_name"]) + ) interceptors.append(RetryInterceptor()) interceptors.append(CadenceErrorInterceptor()) if options["credentials"]: - return secure_channel(options["target"], options["credentials"], options["channel_arguments"], options["compression"], interceptors) + return secure_channel( + options["target"], + options["credentials"], + options["channel_arguments"], + options["compression"], + interceptors, + ) else: - return insecure_channel(options["target"], options["channel_arguments"], options["compression"], interceptors) + return insecure_channel( + options["target"], + options["channel_arguments"], + options["compression"], + interceptors, + ) diff --git a/tests/cadence/test_client_workflow.py b/tests/cadence/test_client_workflow.py new file mode 100644 index 0000000..9f6a7c3 --- /dev/null +++ b/tests/cadence/test_client_workflow.py @@ -0,0 +1,426 @@ +import pytest +import uuid +from datetime import timedelta +from unittest.mock import AsyncMock, Mock, PropertyMock + +from cadence.api.v1.common_pb2 import WorkflowExecution +from cadence.api.v1.service_workflow_pb2 import ( + StartWorkflowExecutionRequest, + StartWorkflowExecutionResponse, +) +from cadence.client import Client, StartWorkflowOptions, _validate_and_apply_defaults +from cadence.data_converter import DefaultDataConverter + + +@pytest.fixture +def mock_client(): + """Create a mock client for testing.""" + client = Mock(spec=Client) + type(client).domain = PropertyMock(return_value="test-domain") + type(client).identity = PropertyMock(return_value="test-identity") + type(client).data_converter = PropertyMock(return_value=DefaultDataConverter()) + type(client).metrics_emitter = PropertyMock(return_value=None) + + # Mock the workflow stub + workflow_stub = Mock() + type(client).workflow_stub = PropertyMock(return_value=workflow_stub) + + return client + + +class TestStartWorkflowOptions: + """Test StartWorkflowOptions TypedDict and validation.""" + + def test_default_values(self): + """Test default values for StartWorkflowOptions.""" + options = StartWorkflowOptions( + task_list="test-task-list", + execution_start_to_close_timeout=timedelta(minutes=10), + task_start_to_close_timeout=timedelta(seconds=30), + ) + assert options.get("workflow_id") is None + assert options["task_list"] == "test-task-list" + assert options["execution_start_to_close_timeout"] == timedelta(minutes=10) + assert options["task_start_to_close_timeout"] == timedelta(seconds=30) + assert options.get("cron_schedule") is None + + def test_custom_values(self): + """Test setting custom values for StartWorkflowOptions.""" + options = StartWorkflowOptions( + workflow_id="custom-id", + task_list="test-task-list", + execution_start_to_close_timeout=timedelta(minutes=30), + task_start_to_close_timeout=timedelta(seconds=10), + cron_schedule="0 * * * *", + ) + + assert options["workflow_id"] == "custom-id" + assert options["task_list"] == "test-task-list" + assert options["execution_start_to_close_timeout"] == timedelta(minutes=30) + assert options["task_start_to_close_timeout"] == timedelta(seconds=10) + assert options["cron_schedule"] == "0 * * * *" + + +class TestClientBuildStartWorkflowRequest: + """Test Client._build_start_workflow_request method.""" + + @pytest.mark.asyncio + async def test_build_request_with_string_workflow(self, mock_client): + """Test building request with string workflow name.""" + # Create real client instance to test the method + client = Client(domain="test-domain", target="localhost:7933") + + options = StartWorkflowOptions( + workflow_id="test-workflow-id", + task_list="test-task-list", + execution_start_to_close_timeout=timedelta(minutes=30), + task_start_to_close_timeout=timedelta(seconds=10), + ) + + request = await client._build_start_workflow_request( + "TestWorkflow", ("arg1", "arg2"), options + ) + + assert isinstance(request, StartWorkflowExecutionRequest) + assert request.domain == "test-domain" + assert request.workflow_id == "test-workflow-id" + assert request.workflow_type.name == "TestWorkflow" + assert request.task_list.name == "test-task-list" + assert request.identity == client.identity + assert ( + request.workflow_id_reuse_policy == 0 + ) # Default protobuf value when not set + assert request.request_id != "" # Should be a UUID + + # Verify UUID format + uuid.UUID(request.request_id) # This will raise if not valid UUID + + @pytest.mark.asyncio + async def test_build_request_with_callable_workflow(self, mock_client): + """Test building request with callable workflow.""" + + def test_workflow(): + pass + + client = Client(domain="test-domain", target="localhost:7933") + + options = StartWorkflowOptions( + task_list="test-task-list", + execution_start_to_close_timeout=timedelta(minutes=10), + task_start_to_close_timeout=timedelta(seconds=30), + ) + + request = await client._build_start_workflow_request(test_workflow, (), options) + + assert request.workflow_type.name == "test_workflow" + + @pytest.mark.asyncio + async def test_build_request_generates_workflow_id(self, mock_client): + """Test that workflow_id is generated when not provided.""" + client = Client(domain="test-domain", target="localhost:7933") + + options = StartWorkflowOptions( + task_list="test-task-list", + execution_start_to_close_timeout=timedelta(minutes=10), + task_start_to_close_timeout=timedelta(seconds=30), + ) + + request = await client._build_start_workflow_request( + "TestWorkflow", (), options + ) + + assert request.workflow_id != "" + # Verify it's a valid UUID + uuid.UUID(request.workflow_id) + + def test_missing_task_list_raises_error(self): + """Test that missing task_list raises ValueError.""" + options = StartWorkflowOptions() + with pytest.raises(ValueError, match="task_list is required"): + _validate_and_apply_defaults(options) + + def test_missing_execution_timeout_raises_error(self): + """Test that missing execution_start_to_close_timeout raises ValueError.""" + options = StartWorkflowOptions(task_list="test-task-list") + with pytest.raises( + ValueError, match="execution_start_to_close_timeout is required" + ): + _validate_and_apply_defaults(options) + + def test_only_execution_timeout(self): + """Test that only execution_start_to_close_timeout works with default task timeout.""" + options = StartWorkflowOptions( + task_list="test-task-list", + execution_start_to_close_timeout=timedelta(minutes=10), + ) + validated_options = _validate_and_apply_defaults(options) + assert validated_options["execution_start_to_close_timeout"] == timedelta( + minutes=10 + ) + assert validated_options["task_start_to_close_timeout"] == timedelta( + seconds=10 + ) # Default applied + + def test_default_task_timeout(self): + """Test that task_start_to_close_timeout defaults to 10 seconds when not provided.""" + options = StartWorkflowOptions( + task_list="test-task-list", + execution_start_to_close_timeout=timedelta(minutes=5), + ) + validated_options = _validate_and_apply_defaults(options) + assert validated_options["task_start_to_close_timeout"] == timedelta(seconds=10) + + def test_only_task_timeout(self): + """Test that only task_start_to_close_timeout raises ValueError (execution timeout required).""" + options = StartWorkflowOptions( + task_list="test-task-list", + task_start_to_close_timeout=timedelta(seconds=30), + ) + with pytest.raises( + ValueError, match="execution_start_to_close_timeout is required" + ): + _validate_and_apply_defaults(options) + + @pytest.mark.asyncio + async def test_build_request_with_input_args(self, mock_client): + """Test building request with input arguments.""" + client = Client(domain="test-domain", target="localhost:7933") + + options = StartWorkflowOptions( + task_list="test-task-list", + execution_start_to_close_timeout=timedelta(minutes=10), + task_start_to_close_timeout=timedelta(seconds=30), + ) + + request = await client._build_start_workflow_request( + "TestWorkflow", ("arg1", 42, {"key": "value"}), options + ) + + # Should have input payload + assert request.HasField("input") + assert len(request.input.data) > 0 + + @pytest.mark.asyncio + async def test_build_request_with_timeouts(self, mock_client): + """Test building request with timeout settings.""" + client = Client(domain="test-domain", target="localhost:7933") + + options = StartWorkflowOptions( + task_list="test-task-list", + execution_start_to_close_timeout=timedelta(minutes=30), + task_start_to_close_timeout=timedelta(seconds=10), + ) + + request = await client._build_start_workflow_request( + "TestWorkflow", (), options + ) + + assert request.HasField("execution_start_to_close_timeout") + assert request.HasField("task_start_to_close_timeout") + + # Check timeout values (30 minutes = 1800 seconds) + assert request.execution_start_to_close_timeout.seconds == 1800 + assert request.task_start_to_close_timeout.seconds == 10 + + @pytest.mark.asyncio + async def test_build_request_with_cron_schedule(self, mock_client): + """Test building request with cron schedule.""" + client = Client(domain="test-domain", target="localhost:7933") + + options = StartWorkflowOptions( + task_list="test-task-list", + execution_start_to_close_timeout=timedelta(minutes=10), + task_start_to_close_timeout=timedelta(seconds=30), + cron_schedule="0 * * * *", + ) + + request = await client._build_start_workflow_request( + "TestWorkflow", (), options + ) + + assert request.cron_schedule == "0 * * * *" + + +class TestClientStartWorkflow: + """Test Client.start_workflow method.""" + + @pytest.mark.asyncio + async def test_start_workflow_success(self, mock_client): + """Test successful workflow start.""" + # Setup mock response + response = StartWorkflowExecutionResponse() + response.run_id = "test-run-id" + + mock_client.workflow_stub.StartWorkflowExecution = AsyncMock( + return_value=response + ) + + # Create a real client but replace the workflow_stub + client = Client(domain="test-domain", target="localhost:7933") + client._workflow_stub = mock_client.workflow_stub + + # Mock the internal method to avoid full request building + async def mock_build_request(workflow, args, options): + request = StartWorkflowExecutionRequest() + request.workflow_id = "test-workflow-id" + request.domain = "test-domain" + return request + + client._build_start_workflow_request = AsyncMock(side_effect=mock_build_request) + + execution = await client.start_workflow( + "TestWorkflow", + "arg1", + "arg2", + task_list="test-task-list", + workflow_id="test-workflow-id", + execution_start_to_close_timeout=timedelta(minutes=10), + task_start_to_close_timeout=timedelta(seconds=30), + ) + + assert isinstance(execution, WorkflowExecution) + assert execution.workflow_id == "test-workflow-id" + assert execution.run_id == "test-run-id" + + # Verify the gRPC call was made + mock_client.workflow_stub.StartWorkflowExecution.assert_called_once() + + @pytest.mark.asyncio + async def test_start_workflow_grpc_error(self, mock_client): + """Test workflow start with gRPC error.""" + # Setup mock to raise exception + mock_client.workflow_stub.StartWorkflowExecution = AsyncMock( + side_effect=Exception("gRPC error") + ) + + client = Client(domain="test-domain", target="localhost:7933") + client._workflow_stub = mock_client.workflow_stub + + # Mock the internal method + client._build_start_workflow_request = AsyncMock( + return_value=StartWorkflowExecutionRequest() + ) + + with pytest.raises(Exception, match="gRPC error"): + await client.start_workflow( + "TestWorkflow", + task_list="test-task-list", + execution_start_to_close_timeout=timedelta(minutes=10), + task_start_to_close_timeout=timedelta(seconds=30), + ) + + @pytest.mark.asyncio + async def test_start_workflow_with_kwargs(self, mock_client): + """Test start_workflow with options as kwargs.""" + response = StartWorkflowExecutionResponse() + response.run_id = "test-run-id" + + mock_client.workflow_stub.StartWorkflowExecution = AsyncMock( + return_value=response + ) + + client = Client(domain="test-domain", target="localhost:7933") + client._workflow_stub = mock_client.workflow_stub + + # Mock the internal method to capture options + captured_options = None + + async def mock_build_request(workflow, args, options): + nonlocal captured_options + captured_options = options + request = StartWorkflowExecutionRequest() + request.workflow_id = "test-workflow-id" + return request + + client._build_start_workflow_request = AsyncMock(side_effect=mock_build_request) + + await client.start_workflow( + "TestWorkflow", + "arg1", + task_list="test-task-list", + workflow_id="custom-id", + execution_start_to_close_timeout=timedelta(minutes=30), + task_start_to_close_timeout=timedelta(seconds=30), + ) + + # Verify options were properly constructed + assert captured_options["task_list"] == "test-task-list" + assert captured_options["workflow_id"] == "custom-id" + assert captured_options["execution_start_to_close_timeout"] == timedelta( + minutes=30 + ) + + @pytest.mark.asyncio + async def test_start_workflow_with_default_task_timeout(self, mock_client): + """Test start_workflow uses default task timeout when not provided.""" + response = StartWorkflowExecutionResponse() + response.run_id = "test-run-id" + + mock_client.workflow_stub.StartWorkflowExecution = AsyncMock( + return_value=response + ) + + client = Client(domain="test-domain", target="localhost:7933") + client._workflow_stub = mock_client.workflow_stub + + # Mock the internal method to capture options + captured_options = None + + async def mock_build_request(workflow, args, options): + nonlocal captured_options + captured_options = options + request = StartWorkflowExecutionRequest() + request.workflow_id = "test-workflow-id" + return request + + client._build_start_workflow_request = AsyncMock(side_effect=mock_build_request) + + await client.start_workflow( + "TestWorkflow", + task_list="test-task-list", + execution_start_to_close_timeout=timedelta(minutes=10), + # No task_start_to_close_timeout provided - should use default + ) + + # Verify default was applied + assert captured_options["task_start_to_close_timeout"] == timedelta(seconds=10) + + +@pytest.mark.asyncio +async def test_integration_workflow_invocation(): + """Integration test for workflow invocation flow.""" + # This test verifies the complete flow works together + response = StartWorkflowExecutionResponse() + response.run_id = "integration-run-id" + + # Create client with mocked gRPC stub + client = Client(domain="test-domain", target="localhost:7933") + client._workflow_stub = Mock() + client._workflow_stub.StartWorkflowExecution = AsyncMock(return_value=response) + + # Test the complete flow + execution = await client.start_workflow( + "IntegrationTestWorkflow", + "test-arg", + 42, + {"data": "value"}, + task_list="integration-task-list", + workflow_id="integration-workflow-id", + execution_start_to_close_timeout=timedelta(minutes=10), + task_start_to_close_timeout=timedelta(seconds=30), + ) + + # Verify result + assert execution.workflow_id == "integration-workflow-id" + assert execution.run_id == "integration-run-id" + + # Verify the gRPC call was made with proper request + client._workflow_stub.StartWorkflowExecution.assert_called_once() + request = client._workflow_stub.StartWorkflowExecution.call_args[0][0] + + assert request.domain == "test-domain" + assert request.workflow_id == "integration-workflow-id" + assert request.workflow_type.name == "IntegrationTestWorkflow" + assert request.task_list.name == "integration-task-list" + assert request.HasField("input") # Should have encoded input + assert request.HasField("execution_start_to_close_timeout")