diff --git a/cadence/__init__.py b/cadence/__init__.py index 175f01b..c1c2a17 100644 --- a/cadence/__init__.py +++ b/cadence/__init__.py @@ -6,9 +6,13 @@ # Import main client functionality from .client import Client +from .worker import Registry +from . import workflow __version__ = "0.1.0" __all__ = [ "Client", + "Registry", + "workflow", ] diff --git a/cadence/_internal/workflow/workflow_engine.py b/cadence/_internal/workflow/workflow_engine.py index 2456cc1..67606b7 100644 --- a/cadence/_internal/workflow/workflow_engine.py +++ b/cadence/_internal/workflow/workflow_engine.py @@ -20,9 +20,12 @@ class DecisionResult: decisions: list[Decision] class WorkflowEngine: - def __init__(self, info: WorkflowInfo, client: Client, workflow_func: Callable[[Any], Any] | None = None): + def __init__(self, info: WorkflowInfo, client: Client, workflow_definition=None): self._context = Context(client, info) - self._workflow_func = workflow_func + self._workflow_definition = workflow_definition + self._workflow_instance = None + if workflow_definition: + self._workflow_instance = workflow_definition.cls() self._decision_manager = DecisionManager() self._decisions_helper = DecisionsHelper(self._decision_manager) self._is_workflow_complete = False @@ -250,19 +253,17 @@ def _fallback_process_workflow_history(self, history) -> None: async def _execute_workflow_function(self, decision_task: PollForDecisionTaskResponse) -> None: """ Execute the workflow function to generate new decisions. - + This blocks until the workflow schedules an activity or completes. - + Args: decision_task: The decision task containing workflow context """ try: - # Execute the workflow function - # The workflow function should block until it schedules an activity - workflow_func = self._workflow_func - if workflow_func is None: + # Execute the workflow function from the workflow instance + if self._workflow_definition is None or self._workflow_instance is None: logger.warning( - "No workflow function available", + "No workflow definition or instance available", extra={ "workflow_type": self._context.info().workflow_type, "workflow_id": self._context.info().workflow_id, @@ -271,11 +272,14 @@ async def _execute_workflow_function(self, decision_task: PollForDecisionTaskRes ) return + # Get the workflow run method from the instance + workflow_func = self._workflow_definition.get_run_method(self._workflow_instance) + # Extract workflow input from history workflow_input = await self._extract_workflow_input(decision_task) # Execute workflow function - result = self._execute_workflow_function_once(workflow_func, workflow_input) + result = await self._execute_workflow_function_once(workflow_func, workflow_input) # Check if workflow is complete if result is not None: @@ -290,7 +294,7 @@ async def _execute_workflow_function(self, decision_task: PollForDecisionTaskRes "completion_type": "success" } ) - + except Exception as e: logger.error( "Error executing workflow function", @@ -337,7 +341,7 @@ async def _extract_workflow_input(self, decision_task: PollForDecisionTaskRespon logger.warning("No WorkflowExecutionStarted event found in history") return None - def _execute_workflow_function_once(self, workflow_func: Callable, workflow_input: Any) -> Any: + async def _execute_workflow_function_once(self, workflow_func: Callable, workflow_input: Any) -> Any: """ Execute the workflow function once (not during replay). 
@@ -351,23 +355,9 @@ def _execute_workflow_function_once(self, workflow_func: Callable, workflow_inpu logger.debug(f"Executing workflow function with input: {workflow_input}") result = workflow_func(workflow_input) - # If the workflow function is async, we need to handle it properly + # If the workflow function is async, await it properly if asyncio.iscoroutine(result): - # For now, use asyncio.run for async workflow functions - # TODO: Implement proper deterministic event loop for workflow execution - try: - result = asyncio.run(result) - except RuntimeError: - # If we're already in an event loop, create a new task - loop = asyncio.get_event_loop() - if loop.is_running(): - # We can't use asyncio.run inside a running loop - # For now, just get the result (this may not be deterministic) - logger.warning("Async workflow function called within running event loop - may not be deterministic") - # This is a workaround - in a real implementation, we'd need proper task scheduling - result = None - else: - result = loop.run_until_complete(result) + result = await result return result diff --git a/cadence/worker/_decision_task_handler.py b/cadence/worker/_decision_task_handler.py index d35ee66..62f0edb 100644 --- a/cadence/worker/_decision_task_handler.py +++ b/cadence/worker/_decision_task_handler.py @@ -76,7 +76,7 @@ async def _handle_task_implementation(self, task: PollForDecisionTaskResponse) - ) try: - workflow_func = self._registry.get_workflow(workflow_type_name) + workflow_definition = self._registry.get_workflow(workflow_type_name) except KeyError: logger.error( "Workflow type not found in registry", @@ -103,9 +103,9 @@ async def _handle_task_implementation(self, task: PollForDecisionTaskResponse) - workflow_engine = self._workflow_engines.get(cache_key) if workflow_engine is None: workflow_engine = WorkflowEngine( - info=workflow_info, - client=self._client, - workflow_func=workflow_func + info=workflow_info, + client=self._client, + workflow_definition=workflow_definition ) self._workflow_engines[cache_key] = workflow_engine diff --git a/cadence/worker/_registry.py b/cadence/worker/_registry.py index d60521d..f45c42a 100644 --- a/cadence/worker/_registry.py +++ b/cadence/worker/_registry.py @@ -7,11 +7,15 @@ """ import logging -from typing import Callable, Dict, Optional, Unpack, TypedDict, Sequence, overload +from typing import Callable, Dict, Optional, Unpack, TypedDict, overload, Type, Union, TypeVar from cadence.activity import ActivityDefinitionOptions, ActivityDefinition, ActivityDecorator, P, T +from cadence.workflow import WorkflowDefinition, WorkflowDefinitionOptions logger = logging.getLogger(__name__) +# TypeVar for workflow class types +W = TypeVar('W') + class RegisterWorkflowOptions(TypedDict, total=False): """Options for registering a workflow.""" @@ -28,53 +32,58 @@ class Registry: def __init__(self) -> None: """Initialize the registry.""" - self._workflows: Dict[str, Callable] = {} + self._workflows: Dict[str, WorkflowDefinition] = {} self._activities: Dict[str, ActivityDefinition] = {} self._workflow_aliases: Dict[str, str] = {} # alias -> name mapping def workflow( self, - func: Optional[Callable] = None, + cls: Optional[Type[W]] = None, **kwargs: Unpack[RegisterWorkflowOptions] - ) -> Callable: + ) -> Union[Type[W], Callable[[Type[W]], Type[W]]]: """ - Register a workflow function. - + Register a workflow class. + This method can be used as a decorator or called directly. - + Only supports class-based workflows. 
+ Args: - func: The workflow function to register + cls: The workflow class to register **kwargs: Options for registration (name, alias) - + Returns: - The decorated function or the function itself - + The decorated class + Raises: KeyError: If workflow name already exists + ValueError: If class workflow is invalid """ options = RegisterWorkflowOptions(**kwargs) - - def decorator(f: Callable) -> Callable: - workflow_name = options.get('name') or f.__name__ - + + def decorator(target: Type[W]) -> Type[W]: + workflow_name = options.get('name') or target.__name__ + if workflow_name in self._workflows: raise KeyError(f"Workflow '{workflow_name}' is already registered") - - self._workflows[workflow_name] = f - + + # Create WorkflowDefinition with type information + workflow_opts = WorkflowDefinitionOptions(name=workflow_name) + workflow_def = WorkflowDefinition.wrap(target, workflow_opts) + self._workflows[workflow_name] = workflow_def + # Register alias if provided alias = options.get('alias') if alias: if alias in self._workflow_aliases: raise KeyError(f"Workflow alias '{alias}' is already registered") self._workflow_aliases[alias] = workflow_name - + logger.info(f"Registered workflow '{workflow_name}'") - return f - - if func is None: + return target + + if cls is None: return decorator - return decorator(func) + return decorator(cls) @overload def activity(self, func: Callable[P, T]) -> ActivityDefinition[P, T]: @@ -135,25 +144,25 @@ def _register_activity(self, defn: ActivityDefinition) -> None: self._activities[defn.name] = defn - def get_workflow(self, name: str) -> Callable: + def get_workflow(self, name: str) -> WorkflowDefinition: """ Get a registered workflow by name. - + Args: name: Name or alias of the workflow - + Returns: - The workflow function - + The workflow definition + Raises: KeyError: If workflow is not found """ # Check if it's an alias actual_name = self._workflow_aliases.get(name, name) - + if actual_name not in self._workflows: raise KeyError(f"Workflow '{name}' not found in registry") - + return self._workflows[actual_name] def get_activity(self, name: str) -> ActivityDefinition: @@ -188,7 +197,7 @@ def of(*args: 'Registry') -> 'Registry': return result -def _find_activity_definitions(instance: object) -> Sequence[ActivityDefinition]: +def _find_activity_definitions(instance: object) -> list[ActivityDefinition]: attr_to_def = {} for t in instance.__class__.__mro__: for attr in dir(t): @@ -200,10 +209,7 @@ def _find_activity_definitions(instance: object) -> Sequence[ActivityDefinition] raise ValueError(f"'{attr}' was overridden with a duplicate activity definition") attr_to_def[attr] = value - # Create new definitions, copying the attributes from the declaring type but using the function - # from the specific object. 
This allows for the decorator to be applied to the base class and the - # function to be overridden - result = [] + result: list[ActivityDefinition] = [] for attr, definition in attr_to_def.items(): result.append(ActivityDefinition(getattr(instance, attr), definition.name, definition.strategy, definition.params)) diff --git a/cadence/workflow.py b/cadence/workflow.py index 51b968f..14cabec 100644 --- a/cadence/workflow.py +++ b/cadence/workflow.py @@ -2,10 +2,127 @@ from contextlib import contextmanager from contextvars import ContextVar from dataclasses import dataclass -from typing import Iterator +from typing import Iterator, Callable, TypeVar, TypedDict, Type, cast, Any, Optional, Union +import inspect from cadence.client import Client +T = TypeVar('T', bound=Callable[..., Any]) + + +class WorkflowDefinitionOptions(TypedDict, total=False): + """Options for defining a workflow.""" + name: str + + +class WorkflowDefinition: + """ + Definition of a workflow class with metadata. + + Similar to ActivityDefinition but for workflow classes. + Provides type safety and metadata for workflow classes. + """ + + def __init__(self, cls: Type, name: str, run_method_name: str): + self._cls = cls + self._name = name + self._run_method_name = run_method_name + + @property + def name(self) -> str: + """Get the workflow name.""" + return self._name + + @property + def cls(self) -> Type: + """Get the workflow class.""" + return self._cls + + def get_run_method(self, instance: Any) -> Callable: + """Get the workflow run method from an instance of the workflow class.""" + return cast(Callable, getattr(instance, self._run_method_name)) + + @staticmethod + def wrap(cls: Type, opts: WorkflowDefinitionOptions) -> 'WorkflowDefinition': + """ + Wrap a class as a WorkflowDefinition. + + Args: + cls: The workflow class to wrap + opts: Options for the workflow definition + + Returns: + A WorkflowDefinition instance + + Raises: + ValueError: If no run method is found or multiple run methods exist + """ + name = cls.__name__ + if "name" in opts and opts["name"]: + name = opts["name"] + + # Validate that the class has exactly one run method and find it + run_method_name = None + for attr_name in dir(cls): + if attr_name.startswith('_'): + continue + + attr = getattr(cls, attr_name) + if not callable(attr): + continue + + # Check for workflow run method + if hasattr(attr, '_workflow_run'): + if run_method_name is not None: + raise ValueError(f"Multiple @workflow.run methods found in class {cls.__name__}") + run_method_name = attr_name + + if run_method_name is None: + raise ValueError(f"No @workflow.run method found in class {cls.__name__}") + + return WorkflowDefinition(cls, name, run_method_name) + + +def run(func: Optional[T] = None) -> Union[T, Callable[[T], T]]: + """ + Decorator to mark a method as the main workflow run method. + + Can be used with or without parentheses: + @workflow.run + async def my_workflow(self): + ... + + @workflow.run() + async def my_workflow(self): + ... 
+ + Args: + func: The method to mark as the workflow run method + + Returns: + The decorated method with workflow run metadata + + Raises: + ValueError: If the function is not async + """ + def decorator(f: T) -> T: + # Validate that the function is async + if not inspect.iscoroutinefunction(f): + raise ValueError(f"Workflow run method '{f.__name__}' must be async") + + # Attach metadata to the function + f._workflow_run = True # type: ignore + return f + + # Support both @workflow.run and @workflow.run() + if func is None: + # Called with parentheses: @workflow.run() + return decorator + else: + # Called without parentheses: @workflow.run + return decorator(func) + + @dataclass class WorkflowInfo: workflow_type: str diff --git a/tests/cadence/_internal/workflow/test_workflow_engine_integration.py b/tests/cadence/_internal/workflow/test_workflow_engine_integration.py index cb1f449..3805f56 100644 --- a/tests/cadence/_internal/workflow/test_workflow_engine_integration.py +++ b/tests/cadence/_internal/workflow/test_workflow_engine_integration.py @@ -9,7 +9,8 @@ from cadence.api.v1.common_pb2 import Payload, WorkflowExecution, WorkflowType from cadence.api.v1.history_pb2 import History, HistoryEvent, WorkflowExecutionStartedEventAttributes from cadence._internal.workflow.workflow_engine import WorkflowEngine, DecisionResult -from cadence.workflow import WorkflowInfo +from cadence import workflow +from cadence.workflow import WorkflowInfo, WorkflowDefinition, WorkflowDefinitionOptions from cadence.client import Client @@ -36,19 +37,23 @@ def workflow_info(self): ) @pytest.fixture - def mock_workflow_func(self): - """Create a mock workflow function.""" - def workflow_func(input_data): - return f"processed: {input_data}" - return workflow_func + def mock_workflow_definition(self): + """Create a mock workflow definition.""" + class TestWorkflow: + @workflow.run + async def weird_name(self, input_data): + return f"processed: {input_data}" + + workflow_opts = WorkflowDefinitionOptions(name="test_workflow") + return WorkflowDefinition.wrap(TestWorkflow, workflow_opts) @pytest.fixture - def workflow_engine(self, mock_client, workflow_info, mock_workflow_func): + def workflow_engine(self, mock_client, workflow_info, mock_workflow_definition): """Create a WorkflowEngine instance.""" return WorkflowEngine( info=workflow_info, client=mock_client, - workflow_func=mock_workflow_func + workflow_definition=mock_workflow_definition ) def create_mock_decision_task(self, workflow_id="test-workflow", run_id="test-run", workflow_type="test_workflow"): @@ -208,17 +213,22 @@ async def test_extract_workflow_input_deserialization_error(self, workflow_engin # Verify no input was extracted due to error assert input_data is None - def test_execute_workflow_function_sync(self, workflow_engine): + @pytest.mark.asyncio + async def test_execute_workflow_function_sync(self, workflow_engine): """Test synchronous workflow function execution.""" input_data = "test-input" - + + # Get the workflow function from the instance + workflow_func = workflow_engine._workflow_definition.get_run_method(workflow_engine._workflow_instance) + # Execute the workflow function - result = workflow_engine._execute_workflow_function_once(workflow_engine._workflow_func, input_data) - + result = await workflow_engine._execute_workflow_function_once(workflow_func, input_data) + # Verify the result assert result == "processed: test-input" - def test_execute_workflow_function_async(self, workflow_engine): + @pytest.mark.asyncio + async def 
test_execute_workflow_function_async(self, workflow_engine): """Test asynchronous workflow function execution.""" async def async_workflow_func(input_data): return f"async-processed: {input_data}" @@ -226,33 +236,35 @@ async def async_workflow_func(input_data): input_data = "test-input" # Execute the async workflow function - result = workflow_engine._execute_workflow_function_once(async_workflow_func, input_data) + result = await workflow_engine._execute_workflow_function_once(async_workflow_func, input_data) # Verify the result assert result == "async-processed: test-input" - def test_execute_workflow_function_none(self, workflow_engine): + @pytest.mark.asyncio + async def test_execute_workflow_function_none(self, workflow_engine): """Test workflow function execution with None function.""" input_data = "test-input" # Execute with None workflow function - should raise TypeError with pytest.raises(TypeError, match="'NoneType' object is not callable"): - workflow_engine._execute_workflow_function_once(None, input_data) + await workflow_engine._execute_workflow_function_once(None, input_data) - def test_workflow_engine_initialization(self, workflow_engine, workflow_info, mock_client, mock_workflow_func): + def test_workflow_engine_initialization(self, workflow_engine, workflow_info, mock_client, mock_workflow_definition): """Test WorkflowEngine initialization.""" assert workflow_engine._context is not None - assert workflow_engine._workflow_func == mock_workflow_func + assert workflow_engine._workflow_definition == mock_workflow_definition + assert workflow_engine._workflow_instance is not None assert workflow_engine._decision_manager is not None assert workflow_engine._is_workflow_complete is False @pytest.mark.asyncio - async def test_workflow_engine_without_workflow_func(self, mock_client, workflow_info): - """Test WorkflowEngine without workflow function.""" + async def test_workflow_engine_without_workflow_definition(self, mock_client, workflow_info): + """Test WorkflowEngine without workflow definition.""" engine = WorkflowEngine( info=workflow_info, client=mock_client, - workflow_func=None + workflow_definition=None ) decision_task = self.create_mock_decision_task() @@ -269,12 +281,19 @@ async def test_workflow_engine_without_workflow_func(self, mock_client, workflow async def test_workflow_engine_workflow_completion(self, workflow_engine, mock_client): """Test workflow completion detection.""" decision_task = self.create_mock_decision_task() - - # Mock workflow function to return a result (indicating completion) - def completing_workflow_func(input_data): - return "workflow-completed" - - workflow_engine._workflow_func = completing_workflow_func + + # Create a workflow definition that returns a result (indicating completion) + class CompletingWorkflow: + @workflow.run + async def run(self, input_data): + return "workflow-completed" + + workflow_opts = WorkflowDefinitionOptions(name="completing_workflow") + completing_definition = WorkflowDefinition.wrap(CompletingWorkflow, workflow_opts) + + # Replace the workflow definition and instance + workflow_engine._workflow_definition = completing_definition + workflow_engine._workflow_instance = completing_definition.cls() with patch.object(workflow_engine._decision_manager, 'collect_pending_decisions', return_value=[]): # Process the decision diff --git a/tests/cadence/worker/test_decision_task_handler.py b/tests/cadence/worker/test_decision_task_handler.py index cd2b210..da2e79a 100644 --- 
a/tests/cadence/worker/test_decision_task_handler.py +++ b/tests/cadence/worker/test_decision_task_handler.py @@ -17,6 +17,8 @@ from cadence.worker._decision_task_handler import DecisionTaskHandler from cadence.worker._registry import Registry from cadence._internal.workflow.workflow_engine import WorkflowEngine, DecisionResult +from cadence import workflow +from cadence.workflow import WorkflowDefinition, WorkflowDefinitionOptions class TestDecisionTaskHandler: @@ -82,9 +84,15 @@ def test_initialization(self, mock_client, mock_registry): @pytest.mark.asyncio async def test_handle_task_implementation_success(self, handler, sample_decision_task, mock_registry): """Test successful decision task handling.""" - # Mock workflow function - mock_workflow_func = Mock() - mock_registry.get_workflow.return_value = mock_workflow_func + # Create actual workflow definition + class MockWorkflow: + @workflow.run + async def run(self): + return "test_result" + + workflow_opts = WorkflowDefinitionOptions(name="test_workflow") + workflow_definition = WorkflowDefinition.wrap(MockWorkflow, workflow_opts) + mock_registry.get_workflow.return_value = workflow_definition # Mock workflow engine mock_engine = Mock(spec=WorkflowEngine) @@ -142,9 +150,15 @@ async def test_handle_task_implementation_workflow_not_found(self, handler, samp @pytest.mark.asyncio async def test_handle_task_implementation_caches_engines(self, handler, sample_decision_task, mock_registry): """Test that decision task handler caches workflow engines for same workflow execution.""" - # Mock workflow function - mock_workflow_func = Mock() - mock_registry.get_workflow.return_value = mock_workflow_func + # Create actual workflow definition + class MockWorkflow: + @workflow.run + async def run(self): + return "test_result" + + workflow_opts = WorkflowDefinitionOptions(name="test_workflow") + workflow_definition = WorkflowDefinition.wrap(MockWorkflow, workflow_opts) + mock_registry.get_workflow.return_value = workflow_definition # Mock workflow engine mock_engine = Mock(spec=WorkflowEngine) @@ -172,9 +186,15 @@ async def test_handle_task_implementation_caches_engines(self, handler, sample_d @pytest.mark.asyncio async def test_handle_task_implementation_different_executions_get_separate_engines(self, handler, mock_registry): """Test that different workflow executions get separate engines.""" - # Mock workflow function - mock_workflow_func = Mock() - mock_registry.get_workflow.return_value = mock_workflow_func + # Create actual workflow definition + class MockWorkflow: + @workflow.run + async def run(self): + return "test_result" + + workflow_opts = WorkflowDefinitionOptions(name="test_workflow") + workflow_definition = WorkflowDefinition.wrap(MockWorkflow, workflow_opts) + mock_registry.get_workflow.return_value = workflow_definition # Create two different decision tasks task1 = Mock(spec=PollForDecisionTaskResponse) @@ -323,19 +343,26 @@ async def test_respond_decision_task_completed_error(self, handler, sample_decis @pytest.mark.asyncio async def test_workflow_engine_creation_with_workflow_info(self, handler, sample_decision_task, mock_registry): """Test that WorkflowEngine is created with correct WorkflowInfo.""" - mock_workflow_func = Mock() - mock_registry.get_workflow.return_value = mock_workflow_func - + # Create actual workflow definition + class MockWorkflow: + @workflow.run + async def run(self): + return "test_result" + + workflow_opts = WorkflowDefinitionOptions(name="test_workflow") + workflow_definition = 
WorkflowDefinition.wrap(MockWorkflow, workflow_opts) + mock_registry.get_workflow.return_value = workflow_definition + mock_engine = Mock(spec=WorkflowEngine) mock_engine._is_workflow_complete = False # Add missing attribute mock_decision_result = Mock(spec=DecisionResult) mock_decision_result.decisions = [] mock_engine.process_decision = AsyncMock(return_value=mock_decision_result) - + with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine) as mock_workflow_engine_class: with patch('cadence.worker._decision_task_handler.WorkflowInfo') as mock_workflow_info_class: await handler._handle_task_implementation(sample_decision_task) - + # Verify WorkflowInfo was created with correct parameters (called once for engine) assert mock_workflow_info_class.call_count == 1 for call in mock_workflow_info_class.call_args_list: @@ -345,10 +372,10 @@ async def test_workflow_engine_creation_with_workflow_info(self, handler, sample 'workflow_id': "test_workflow_id", 'workflow_run_id': "test_run_id" } - + # Verify WorkflowEngine was created with correct parameters mock_workflow_engine_class.assert_called_once() call_args = mock_workflow_engine_class.call_args assert call_args[1]['info'] is not None assert call_args[1]['client'] == handler._client - assert call_args[1]['workflow_func'] == mock_workflow_func + assert call_args[1]['workflow_definition'] == workflow_definition diff --git a/tests/cadence/worker/test_decision_task_handler_integration.py b/tests/cadence/worker/test_decision_task_handler_integration.py index b513a14..fc65f0e 100644 --- a/tests/cadence/worker/test_decision_task_handler_integration.py +++ b/tests/cadence/worker/test_decision_task_handler_integration.py @@ -13,6 +13,7 @@ from cadence.api.v1.decision_pb2 import Decision from cadence.worker._decision_task_handler import DecisionTaskHandler from cadence.worker._registry import Registry +from cadence import workflow from cadence.client import Client @@ -35,12 +36,14 @@ def mock_client(self): def registry(self): """Create a registry with a test workflow.""" reg = Registry() - - @reg.workflow - def test_workflow(input_data): - """Simple test workflow that returns the input.""" - return f"processed: {input_data}" - + + @reg.workflow(name="test_workflow") + class TestWorkflow: + @workflow.run + async def run(self, input_data): + """Simple test workflow that returns the input.""" + return f"processed: {input_data}" + return reg @pytest.fixture diff --git a/tests/cadence/worker/test_decision_worker_integration.py b/tests/cadence/worker/test_decision_worker_integration.py index 85c55d2..18e970e 100644 --- a/tests/cadence/worker/test_decision_worker_integration.py +++ b/tests/cadence/worker/test_decision_worker_integration.py @@ -11,6 +11,7 @@ from cadence.api.v1.history_pb2 import History, HistoryEvent, WorkflowExecutionStartedEventAttributes from cadence.worker._decision import DecisionWorker from cadence.worker._registry import Registry +from cadence import workflow from cadence.client import Client @@ -34,12 +35,14 @@ def mock_client(self): def registry(self): """Create a registry with a test workflow.""" reg = Registry() - + @reg.workflow - def test_workflow(input_data): - """Simple test workflow that returns the input.""" - return f"processed: {input_data}" - + class TestWorkflow: + @workflow.run + async def run(self, input_data): + """Simple test workflow that returns the input.""" + return f"processed: {input_data}" + return reg @pytest.fixture @@ -236,8 +239,10 @@ async def 
test_decision_worker_with_different_workflow_types(self, decision_work """Test decision worker with different workflow types.""" # Add another workflow to the registry @registry.workflow - def another_workflow(input_data): - return f"another-processed: {input_data}" + class AnotherWorkflow: + @workflow.run + async def run(self, input_data): + return f"another-processed: {input_data}" # Create decision tasks for different workflow types task1 = self.create_mock_decision_task(workflow_type="test_workflow") diff --git a/tests/cadence/worker/test_registry.py b/tests/cadence/worker/test_registry.py index 4a8973b..bf6721e 100644 --- a/tests/cadence/worker/test_registry.py +++ b/tests/cadence/worker/test_registry.py @@ -6,7 +6,9 @@ import pytest from cadence import activity +from cadence import workflow from cadence.worker import Registry +from cadence.workflow import WorkflowDefinition from tests.cadence import common_activities @@ -21,24 +23,32 @@ def test_basic_registry_creation(self): with pytest.raises(KeyError): reg.get_activity("nonexistent") - @pytest.mark.parametrize("registration_type", ["workflow", "activity"]) - def test_basic_registration_and_retrieval(self, registration_type): - """Test basic registration and retrieval for both workflows and activities.""" + def test_basic_workflow_registration_and_retrieval(self): + """Test basic registration and retrieval for class-based workflows.""" reg = Registry() - - if registration_type == "workflow": - @reg.workflow - def test_func(): - return "test" - - func = reg.get_workflow("test_func") - else: - @reg.activity - def test_func(): + + @reg.workflow + class TestWorkflow: + @workflow.run + async def run(self): return "test" - - func = reg.get_activity(test_func.name) - + + # Registry stores WorkflowDefinition internally + workflow_def = reg.get_workflow("TestWorkflow") + # Verify it's actually a WorkflowDefinition + assert isinstance(workflow_def, WorkflowDefinition) + assert workflow_def.name == "TestWorkflow" + assert workflow_def.cls == TestWorkflow + + def test_basic_activity_registration_and_retrieval(self): + """Test basic registration and retrieval for activities.""" + reg = Registry() + + @reg.activity + def test_func(): + return "test" + + func = reg.get_activity(test_func.name) assert func() == "test" def test_direct_call_behavior(self): @@ -53,41 +63,47 @@ def test_func(): assert func() == "direct_call" - @pytest.mark.parametrize("registration_type", ["workflow", "activity"]) - def test_not_found_error(self, registration_type): - """Test KeyError is raised when function not found.""" + def test_workflow_not_found_error(self): + """Test KeyError is raised when workflow not found.""" reg = Registry() - - if registration_type == "workflow": - with pytest.raises(KeyError): - reg.get_workflow("nonexistent") - else: - with pytest.raises(KeyError): - reg.get_activity("nonexistent") - - @pytest.mark.parametrize("registration_type", ["workflow", "activity"]) - def test_duplicate_registration_error(self, registration_type): - """Test KeyError is raised for duplicate registrations.""" + with pytest.raises(KeyError): + reg.get_workflow("nonexistent") + + def test_activity_not_found_error(self): + """Test KeyError is raised when activity not found.""" reg = Registry() - - if registration_type == "workflow": - @reg.workflow - def test_func(): + with pytest.raises(KeyError): + reg.get_activity("nonexistent") + + def test_duplicate_workflow_registration_error(self): + """Test KeyError is raised for duplicate workflow registrations.""" + reg = 
Registry() + + @reg.workflow(name="duplicate_test") + class TestWorkflow: + @workflow.run + async def run(self): return "test" - - with pytest.raises(KeyError): - @reg.workflow - def test_func(): + + with pytest.raises(KeyError): + @reg.workflow(name="duplicate_test") + class TestWorkflow2: + @workflow.run + async def run(self): return "duplicate" - else: + + def test_duplicate_activity_registration_error(self): + """Test KeyError is raised for duplicate activity registrations.""" + reg = Registry() + + @reg.activity(name="test_func") + def test_func(): + return "test" + + with pytest.raises(KeyError): @reg.activity(name="test_func") def test_func(): - return "test" - - with pytest.raises(KeyError): - @reg.activity(name="test_func") - def test_func(): - return "duplicate" + return "duplicate" def test_register_activities_instance(self): reg = Registry() @@ -150,3 +166,40 @@ def test_of(self): assert result.get_activity("simple_fn") is not None assert result.get_activity("echo") is not None assert result.get_activity("async_fn") is not None + + def test_class_workflow_validation_errors(self): + """Test validation errors for class-based workflows.""" + reg = Registry() + + # Test missing run method + with pytest.raises(ValueError, match="No @workflow.run method found"): + @reg.workflow + class MissingRunWorkflow: + def some_method(self): + pass + + # Test duplicate run methods + with pytest.raises(ValueError, match="Multiple @workflow.run methods found"): + @reg.workflow + class DuplicateRunWorkflow: + @workflow.run + async def run1(self): + pass + + @workflow.run + async def run2(self): + pass + + def test_class_workflow_with_custom_name(self): + """Test class-based workflow with custom name.""" + reg = Registry() + + @reg.workflow(name="custom_workflow_name") + class CustomWorkflow: + @workflow.run + async def run(self, input: str) -> str: + return f"processed: {input}" + + workflow_def = reg.get_workflow("custom_workflow_name") + assert workflow_def.name == "custom_workflow_name" + assert workflow_def.cls == CustomWorkflow diff --git a/tests/cadence/worker/test_task_handler_integration.py b/tests/cadence/worker/test_task_handler_integration.py index 8e6aef9..daa36bb 100644 --- a/tests/cadence/worker/test_task_handler_integration.py +++ b/tests/cadence/worker/test_task_handler_integration.py @@ -12,6 +12,8 @@ from cadence.worker._decision_task_handler import DecisionTaskHandler from cadence.worker._registry import Registry from cadence._internal.workflow.workflow_engine import WorkflowEngine, DecisionResult +from cadence import workflow +from cadence.workflow import WorkflowDefinition, WorkflowDefinitionOptions class TestTaskHandlerIntegration: @@ -61,11 +63,17 @@ def sample_decision_task(self): @pytest.mark.asyncio async def test_full_task_handling_flow_success(self, handler, sample_decision_task, mock_registry): """Test the complete task handling flow from base handler through decision handler.""" - # Mock workflow function - def mock_workflow_func(input_data): - return f"processed: {input_data}" - - mock_registry.get_workflow.return_value = mock_workflow_func + # Create actual workflow definition + + + class MockWorkflow: + @workflow.run + async def run(self): + return "test_result" + + workflow_opts = WorkflowDefinitionOptions(name="test_workflow") + workflow_definition = WorkflowDefinition.wrap(MockWorkflow, workflow_opts) + mock_registry.get_workflow.return_value = workflow_definition # Mock workflow engine mock_engine = Mock(spec=WorkflowEngine) @@ -86,11 +94,15 @@ def 
mock_workflow_func(input_data): @pytest.mark.asyncio async def test_full_task_handling_flow_with_error(self, handler, sample_decision_task, mock_registry): """Test the complete task handling flow when an error occurs.""" - # Mock workflow function - def mock_workflow_func(input_data): - return f"processed: {input_data}" - - mock_registry.get_workflow.return_value = mock_workflow_func + # Create actual workflow definition + class MockWorkflow: + @workflow.run + async def run(self): + return "test_result" + + workflow_opts = WorkflowDefinitionOptions(name="test_workflow") + workflow_definition = WorkflowDefinition.wrap(MockWorkflow, workflow_opts) + mock_registry.get_workflow.return_value = workflow_definition # Mock workflow engine to raise an error mock_engine = Mock(spec=WorkflowEngine) @@ -110,11 +122,15 @@ def mock_workflow_func(input_data): @pytest.mark.asyncio async def test_context_activation_integration(self, handler, sample_decision_task, mock_registry): """Test that context activation works correctly in the integration.""" - # Mock workflow function - def mock_workflow_func(input_data): - return f"processed: {input_data}" - - mock_registry.get_workflow.return_value = mock_workflow_func + # Create actual workflow definition + class MockWorkflow: + @workflow.run + async def run(self): + return "test_result" + + workflow_opts = WorkflowDefinitionOptions(name="test_workflow") + workflow_definition = WorkflowDefinition.wrap(MockWorkflow, workflow_opts) + mock_registry.get_workflow.return_value = workflow_definition # Mock workflow engine mock_engine = Mock(spec=WorkflowEngine) @@ -144,11 +160,15 @@ def track_context_activation(): @pytest.mark.asyncio async def test_multiple_workflow_executions(self, handler, mock_registry): """Test handling multiple workflow executions creates new engines for each.""" - # Mock workflow function - def mock_workflow_func(input_data): - return f"processed: {input_data}" - - mock_registry.get_workflow.return_value = mock_workflow_func + # Create actual workflow definition + class MockWorkflow: + @workflow.run + async def run(self): + return "test_result" + + workflow_opts = WorkflowDefinitionOptions(name="test_workflow") + workflow_definition = WorkflowDefinition.wrap(MockWorkflow, workflow_opts) + mock_registry.get_workflow.return_value = workflow_definition # Create multiple decision tasks for different workflows task1 = Mock(spec=PollForDecisionTaskResponse) @@ -194,11 +214,15 @@ def mock_workflow_func(input_data): @pytest.mark.asyncio async def test_workflow_engine_creation_integration(self, handler, sample_decision_task, mock_registry): """Test workflow engine creation integration.""" - # Mock workflow function - def mock_workflow_func(input_data): - return f"processed: {input_data}" - - mock_registry.get_workflow.return_value = mock_workflow_func + # Create actual workflow definition + class MockWorkflow: + @workflow.run + async def run(self): + return "test_result" + + workflow_opts = WorkflowDefinitionOptions(name="test_workflow") + workflow_definition = WorkflowDefinition.wrap(MockWorkflow, workflow_opts) + mock_registry.get_workflow.return_value = workflow_definition # Mock workflow engine mock_engine = Mock(spec=WorkflowEngine) @@ -218,11 +242,15 @@ def mock_workflow_func(input_data): @pytest.mark.asyncio async def test_error_handling_with_context_cleanup(self, handler, sample_decision_task, mock_registry): """Test that context cleanup happens even when errors occur.""" - # Mock workflow function - def mock_workflow_func(input_data): - return 
f"processed: {input_data}" - - mock_registry.get_workflow.return_value = mock_workflow_func + # Create actual workflow definition + class MockWorkflow: + @workflow.run + async def run(self): + return "test_result" + + workflow_opts = WorkflowDefinitionOptions(name="test_workflow") + workflow_definition = WorkflowDefinition.wrap(MockWorkflow, workflow_opts) + mock_registry.get_workflow.return_value = workflow_definition # Mock workflow engine to raise an error mock_engine = Mock(spec=WorkflowEngine) @@ -255,11 +283,15 @@ async def test_concurrent_task_handling(self, handler, mock_registry): """Test handling multiple tasks concurrently.""" import asyncio - # Mock workflow function - def mock_workflow_func(input_data): - return f"processed: {input_data}" - - mock_registry.get_workflow.return_value = mock_workflow_func + # Create actual workflow definition + class MockWorkflow: + @workflow.run + async def run(self): + return "test_result" + + workflow_opts = WorkflowDefinitionOptions(name="test_workflow") + workflow_definition = WorkflowDefinition.wrap(MockWorkflow, workflow_opts) + mock_registry.get_workflow.return_value = workflow_definition # Create multiple tasks tasks = []