From edf65c0882c5b51858a4ba7cb59102941460e5d6 Mon Sep 17 00:00:00 2001 From: Nate Mortensen Date: Mon, 15 Sep 2025 11:15:01 -0700 Subject: [PATCH] Introduce ActivityDefinition as a wrapper for Activity Fns In order to use activity functions within a Workflow we need to have a very explicit name for the function that we can determine without the registry, as the activity may not even be present within the same registry as the workflow. Remove the concept of aliases, and mandate a decorator on all Activity functions. For ease of use, continue to support the decorator via the registry to both register and decorate the activity at the same time. In addition, rather than simply storing the activity function or setting attributes on it to track that it's registered, introduce a wrapping Callable. Within the context of a Workflow we can reinterpret invocations of the activity function as the execution of the activity itself. Signed-off-by: Nate Mortensen --- .../_internal/activity/_activity_executor.py | 13 +- cadence/_internal/activity/_context.py | 11 +- cadence/_internal/type_utils.py | 19 -- cadence/activity.py | 103 +++++++++- cadence/worker/__init__.py | 2 - cadence/worker/_registry.py | 139 ++++++++----- .../activity/test_activity_executor.py | 49 +++-- tests/cadence/_internal/test_type_utils.py | 70 ------- tests/cadence/common_activities.py | 51 +++++ tests/cadence/worker/test_registry.py | 185 +++++++----------- 10 files changed, 356 insertions(+), 286 deletions(-) delete mode 100644 cadence/_internal/type_utils.py delete mode 100644 tests/cadence/_internal/test_type_utils.py create mode 100644 tests/cadence/common_activities.py diff --git a/cadence/_internal/activity/_activity_executor.py b/cadence/_internal/activity/_activity_executor.py index f9efba0..e37e736 100644 --- a/cadence/_internal/activity/_activity_executor.py +++ b/cadence/_internal/activity/_activity_executor.py @@ -1,4 +1,3 @@ -import inspect from concurrent.futures import ThreadPoolExecutor from logging import getLogger from traceback import format_exception @@ -7,7 +6,7 @@ from google.protobuf.timestamp import to_datetime from cadence._internal.activity._context import _Context, _SyncContext -from cadence.activity import ActivityInfo +from cadence.activity import ActivityInfo, ActivityDefinition, ExecutionStrategy from cadence.api.v1.common_pb2 import Failure from cadence.api.v1.service_worker_pb2 import PollForActivityTaskResponse, RespondActivityTaskFailedRequest, \ RespondActivityTaskCompletedRequest @@ -16,7 +15,7 @@ _logger = getLogger(__name__) class ActivityExecutor: - def __init__(self, client: Client, task_list: str, identity: str, max_workers: int, registry: Callable[[str], Callable]): + def __init__(self, client: Client, task_list: str, identity: str, max_workers: int, registry: Callable[[str], ActivityDefinition]): self._client = client self._data_converter = client.data_converter self._registry = registry @@ -36,16 +35,16 @@ async def execute(self, task: PollForActivityTaskResponse): def _create_context(self, task: PollForActivityTaskResponse) -> _Context: activity_type = task.activity_type.name try: - activity_fn = self._registry(activity_type) + activity_def = self._registry(activity_type) except KeyError: raise KeyError(f"Activity type not found: {activity_type}") from None info = self._create_info(task) - if inspect.iscoroutinefunction(activity_fn): - return _Context(self._client, info, activity_fn) + if activity_def.strategy == ExecutionStrategy.ASYNC: + return _Context(self._client, info, activity_def) else: - return _SyncContext(self._client, info, activity_fn, self._thread_pool) + return _SyncContext(self._client, info, activity_def, self._thread_pool) async def _report_failure(self, task: PollForActivityTaskResponse, error: Exception): try: diff --git a/cadence/_internal/activity/_context.py b/cadence/_internal/activity/_context.py index 208b859..ce2f94b 100644 --- a/cadence/_internal/activity/_context.py +++ b/cadence/_internal/activity/_context.py @@ -1,15 +1,14 @@ import asyncio from concurrent.futures.thread import ThreadPoolExecutor -from typing import Callable, Any +from typing import Any from cadence import Client -from cadence._internal.type_utils import get_fn_parameters -from cadence.activity import ActivityInfo, ActivityContext +from cadence.activity import ActivityInfo, ActivityContext, ActivityDefinition from cadence.api.v1.common_pb2 import Payload class _Context(ActivityContext): - def __init__(self, client: Client, info: ActivityInfo, activity_fn: Callable[[Any], Any]): + def __init__(self, client: Client, info: ActivityInfo, activity_fn: ActivityDefinition[[Any], Any]): self._client = client self._info = info self._activity_fn = activity_fn @@ -20,7 +19,7 @@ async def execute(self, payload: Payload) -> Any: return await self._activity_fn(*params) async def _to_params(self, payload: Payload) -> list[Any]: - type_hints = get_fn_parameters(self._activity_fn) + type_hints = [param.type_hint for param in self._activity_fn.params] return await self._client.data_converter.from_data(payload, type_hints) def client(self) -> Client: @@ -30,7 +29,7 @@ def info(self) -> ActivityInfo: return self._info class _SyncContext(_Context): - def __init__(self, client: Client, info: ActivityInfo, activity_fn: Callable[[Any], Any], executor: ThreadPoolExecutor): + def __init__(self, client: Client, info: ActivityInfo, activity_fn: ActivityDefinition[[Any], Any], executor: ThreadPoolExecutor): super().__init__(client, info, activity_fn) self._executor = executor diff --git a/cadence/_internal/type_utils.py b/cadence/_internal/type_utils.py deleted file mode 100644 index 84fd07c..0000000 --- a/cadence/_internal/type_utils.py +++ /dev/null @@ -1,19 +0,0 @@ -from inspect import signature, Parameter -from typing import Callable, List, Type, get_type_hints - -def get_fn_parameters(fn: Callable) -> List[Type | None]: - args = signature(fn).parameters - hints = get_type_hints(fn) - result = [] - for name, param in args.items(): - if param.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD): - type_hint = hints.get(name, None) - result.append(type_hint) - - return result - -def validate_fn_parameters(fn: Callable) -> None: - args = signature(fn).parameters - for name, param in args.items(): - if param.kind not in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD): - raise ValueError(f"Parameters must be positional. {name} is {param.kind}, and not valid") \ No newline at end of file diff --git a/cadence/activity.py b/cadence/activity.py index 0f71fb0..57a9b48 100644 --- a/cadence/activity.py +++ b/cadence/activity.py @@ -1,9 +1,14 @@ +import inspect from abc import ABC, abstractmethod from contextlib import contextmanager from contextvars import ContextVar from dataclasses import dataclass from datetime import timedelta, datetime -from typing import Iterator +from enum import Enum +from functools import update_wrapper +from inspect import signature, Parameter +from typing import Iterator, TypedDict, Unpack, Callable, Type, ParamSpec, TypeVar, Generic, get_type_hints, \ + Any, overload from cadence import Client @@ -59,3 +64,99 @@ def is_set() -> bool: @staticmethod def get() -> 'ActivityContext': return ActivityContext._var.get() + + +@dataclass(frozen=True) +class ActivityParameter: + name: str + type_hint: Type | None + default_value: Any | None + +class ExecutionStrategy(Enum): + ASYNC = "async" + THREAD_POOL = "thread_pool" + +class ActivityDefinitionOptions(TypedDict, total=False): + name: str + +P = ParamSpec('P') +T = TypeVar('T') + +class ActivityDefinition(Generic[P, T]): + def __init__(self, wrapped: Callable[P, T], name: str, strategy: ExecutionStrategy, params: list[ActivityParameter]): + self._wrapped = wrapped + self._name = name + self._strategy = strategy + self._params = params + update_wrapper(self, wrapped) + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: + return self._wrapped(*args, **kwargs) + + @property + def name(self) -> str: + return self._name + + @property + def strategy(self) -> ExecutionStrategy: + return self._strategy + + @property + def params(self) -> list[ActivityParameter]: + return self._params + + @staticmethod + def wrap(fn: Callable[P, T], opts: ActivityDefinitionOptions) -> 'ActivityDefinition[P, T]': + name = fn.__qualname__ + if "name" in opts and opts["name"]: + name = opts["name"] + + strategy = ExecutionStrategy.THREAD_POOL + if inspect.iscoroutinefunction(fn) or inspect.iscoroutinefunction(fn.__call__): # type: ignore + strategy = ExecutionStrategy.ASYNC + + params = _get_params(fn) + return ActivityDefinition(fn, name, strategy, params) + + +ActivityDecorator = Callable[[Callable[P, T]], ActivityDefinition[P, T]] + +@overload +def defn(fn: Callable[P, T]) -> ActivityDefinition[P, T]: + ... + +@overload +def defn(**kwargs: Unpack[ActivityDefinitionOptions]) -> ActivityDecorator: + ... + +def defn(fn: Callable[P, T] | None = None, **kwargs: Unpack[ActivityDefinitionOptions]) -> ActivityDecorator | ActivityDefinition[P, T]: + options = ActivityDefinitionOptions(**kwargs) + def decorator(inner_fn: Callable[P, T]) -> ActivityDefinition[P, T]: + return ActivityDefinition.wrap(inner_fn, options) + + if fn is not None: + return decorator(fn) + + return decorator + + +def _get_params(fn: Callable) -> list[ActivityParameter]: + args = signature(fn).parameters + hints = get_type_hints(fn) + result = [] + for name, param in args.items(): + # "unbound functions" aren't a thing in the Python spec. Filter out the self parameter and hope they followed + # the convention. + if param.name == "self": + continue + default = None + if param.default != Parameter.empty: + default = param.default + if param.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD): + type_hint = hints.get(name, None) + result.append(ActivityParameter(name, type_hint, default)) + + else: + raise ValueError(f"Parameters must be positional. {name} is {param.kind}, and not valid") + + return result diff --git a/cadence/worker/__init__.py b/cadence/worker/__init__.py index 6249d28..4084e9a 100644 --- a/cadence/worker/__init__.py +++ b/cadence/worker/__init__.py @@ -8,7 +8,6 @@ from ._registry import ( Registry, RegisterWorkflowOptions, - RegisterActivityOptions, ) __all__ = [ @@ -16,5 +15,4 @@ "WorkerOptions", 'Registry', 'RegisterWorkflowOptions', - 'RegisterActivityOptions', ] diff --git a/cadence/worker/_registry.py b/cadence/worker/_registry.py index 1f5d03f..d60521d 100644 --- a/cadence/worker/_registry.py +++ b/cadence/worker/_registry.py @@ -7,9 +7,8 @@ """ import logging -from typing import Callable, Dict, Optional, Unpack, TypedDict -from cadence._internal.type_utils import validate_fn_parameters - +from typing import Callable, Dict, Optional, Unpack, TypedDict, Sequence, overload +from cadence.activity import ActivityDefinitionOptions, ActivityDefinition, ActivityDecorator, P, T logger = logging.getLogger(__name__) @@ -19,13 +18,6 @@ class RegisterWorkflowOptions(TypedDict, total=False): name: Optional[str] alias: Optional[str] - -class RegisterActivityOptions(TypedDict, total=False): - """Options for registering an activity.""" - name: Optional[str] - alias: Optional[str] - - class Registry: """ Registry for managing workflows and activities. @@ -37,10 +29,9 @@ class Registry: def __init__(self) -> None: """Initialize the registry.""" self._workflows: Dict[str, Callable] = {} - self._activities: Dict[str, Callable] = {} + self._activities: Dict[str, ActivityDefinition] = {} self._workflow_aliases: Dict[str, str] = {} # alias -> name mapping - self._activity_aliases: Dict[str, str] = {} # alias -> name mapping - + def workflow( self, func: Optional[Callable] = None, @@ -84,12 +75,16 @@ def decorator(f: Callable) -> Callable: if func is None: return decorator return decorator(func) + + @overload + def activity(self, func: Callable[P, T]) -> ActivityDefinition[P, T]: + ... + + @overload + def activity(self, **kwargs: Unpack[ActivityDefinitionOptions]) -> ActivityDecorator: + ... - def activity( - self, - func: Optional[Callable] = None, - **kwargs: Unpack[RegisterActivityOptions] - ) -> Callable: + def activity(self, func: Callable[P, T] | None = None, **kwargs: Unpack[ActivityDefinitionOptions]) -> ActivityDecorator | ActivityDefinition[P, T]: """ Register an activity function. @@ -105,30 +100,40 @@ def activity( Raises: KeyError: If activity name already exists """ - options = RegisterActivityOptions(**kwargs) - - def decorator(f: Callable) -> Callable: - validate_fn_parameters(f) - activity_name = options.get('name') or f.__name__ - - if activity_name in self._activities: - raise KeyError(f"Activity '{activity_name}' is already registered") - - self._activities[activity_name] = f - - # Register alias if provided - alias = options.get('alias') - if alias: - if alias in self._activity_aliases: - raise KeyError(f"Activity alias '{alias}' is already registered") - self._activity_aliases[alias] = activity_name - - logger.info(f"Registered activity '{activity_name}'") - return f - - if func is None: - return decorator - return decorator(func) + options = ActivityDefinitionOptions(**kwargs) + + def decorator(f: Callable[P, T]) -> ActivityDefinition[P, T]: + defn = ActivityDefinition.wrap(f, options) + + self._register_activity(defn) + + return defn + + if func is not None: + return decorator(func) + + return decorator + + def register_activities(self, obj: object) -> None: + activities = _find_activity_definitions(obj) + if not activities: + raise ValueError(f"No activity definitions found in '{repr(obj)}'") + + for defn in activities: + self._register_activity(defn) + + + def register_activity(self, defn: Callable) -> None: + if not isinstance(defn, ActivityDefinition): + raise ValueError(f"{defn.__qualname__} must have @activity.defn decorator") + self._register_activity(defn) + + def _register_activity(self, defn: ActivityDefinition) -> None: + if defn.name in self._activities: + raise KeyError(f"Activity '{defn.name}' is already registered") + + self._activities[defn.name] = defn + def get_workflow(self, name: str) -> Callable: """ @@ -151,7 +156,7 @@ def get_workflow(self, name: str) -> Callable: return self._workflows[actual_name] - def get_activity(self, name: str) -> Callable: + def get_activity(self, name: str) -> ActivityDefinition: """ Get a registered activity by name. @@ -164,13 +169,45 @@ def get_activity(self, name: str) -> Callable: Raises: KeyError: If activity is not found """ - # Check if it's an alias - actual_name = self._activity_aliases.get(name, name) - - if actual_name not in self._activities: - raise KeyError(f"Activity '{name}' not found in registry") - - return self._activities[actual_name] - + return self._activities[name] + + def __add__(self, other: 'Registry') -> 'Registry': + result = Registry() + for name, fn in self._activities.items(): + result._register_activity(fn) + for name, fn in other._activities.items(): + result._register_activity(fn) + + return result + + @staticmethod + def of(*args: 'Registry') -> 'Registry': + result = Registry() + for other in args: + result += other + + return result + +def _find_activity_definitions(instance: object) -> Sequence[ActivityDefinition]: + attr_to_def = {} + for t in instance.__class__.__mro__: + for attr in dir(t): + if attr.startswith("_"): + continue + value = getattr(t, attr) + if isinstance(value, ActivityDefinition): + if attr in attr_to_def: + raise ValueError(f"'{attr}' was overridden with a duplicate activity definition") + attr_to_def[attr] = value + + # Create new definitions, copying the attributes from the declaring type but using the function + # from the specific object. This allows for the decorator to be applied to the base class and the + # function to be overridden + result = [] + for attr, definition in attr_to_def.items(): + result.append(ActivityDefinition(getattr(instance, attr), definition.name, definition.strategy, definition.params)) + + return result + \ No newline at end of file diff --git a/tests/cadence/_internal/activity/test_activity_executor.py b/tests/cadence/_internal/activity/test_activity_executor.py index 89b95e2..d6aba4d 100644 --- a/tests/cadence/_internal/activity/test_activity_executor.py +++ b/tests/cadence/_internal/activity/test_activity_executor.py @@ -8,11 +8,12 @@ from cadence import activity, Client from cadence._internal.activity import ActivityExecutor -from cadence.activity import ActivityInfo +from cadence.activity import ActivityInfo, ActivityDefinition from cadence.api.v1.common_pb2 import WorkflowExecution, ActivityType, Payload, Failure, WorkflowType from cadence.api.v1.service_worker_pb2 import RespondActivityTaskCompletedResponse, PollForActivityTaskResponse, \ RespondActivityTaskCompletedRequest, RespondActivityTaskFailedResponse, RespondActivityTaskFailedRequest from cadence.data_converter import DefaultDataConverter +from cadence.worker import Registry @pytest.fixture @@ -27,12 +28,14 @@ async def test_activity_async_success(client): worker_stub = client.worker_stub worker_stub.RespondActivityTaskCompleted = AsyncMock(return_value=RespondActivityTaskCompletedResponse()) + reg = Registry() + @reg.activity(name="activity_type") async def activity_fn(): return "success" - executor = ActivityExecutor(client, 'task_list', 'identity', 1, lambda name: activity_fn) + executor = ActivityExecutor(client, 'task_list', 'identity', 1, reg.get_activity) - await executor.execute(fake_task("any", "")) + await executor.execute(fake_task("activity_type", "")) worker_stub.RespondActivityTaskCompleted.assert_called_once_with(RespondActivityTaskCompletedRequest( task_token=b'task_token', @@ -44,12 +47,14 @@ async def test_activity_async_failure(client): worker_stub = client.worker_stub worker_stub.RespondActivityTaskFailed = AsyncMock(return_value=RespondActivityTaskFailedResponse()) + reg = Registry() + @reg.activity(name="activity_type") async def activity_fn(): raise KeyError("failure") - executor = ActivityExecutor(client, 'task_list', 'identity', 1, lambda name: activity_fn) + executor = ActivityExecutor(client, 'task_list', 'identity', 1, reg.get_activity) - await executor.execute(fake_task("any", "")) + await executor.execute(fake_task("activity_type", "")) worker_stub.RespondActivityTaskFailed.assert_called_once() @@ -70,12 +75,14 @@ async def test_activity_args(client): worker_stub = client.worker_stub worker_stub.RespondActivityTaskCompleted = AsyncMock(return_value=RespondActivityTaskCompletedResponse()) + reg = Registry() + @reg.activity(name="activity_type") async def activity_fn(first: str, second: str): return " ".join([first, second]) - executor = ActivityExecutor(client, 'task_list', 'identity', 1, lambda name: activity_fn) + executor = ActivityExecutor(client, 'task_list', 'identity', 1, reg.get_activity) - await executor.execute(fake_task("any", '["hello", "world"]')) + await executor.execute(fake_task("activity_type", '["hello", "world"]')) worker_stub.RespondActivityTaskCompleted.assert_called_once_with(RespondActivityTaskCompletedRequest( task_token=b'task_token', @@ -87,6 +94,8 @@ async def test_activity_sync_success(client): worker_stub = client.worker_stub worker_stub.RespondActivityTaskCompleted = AsyncMock(return_value=RespondActivityTaskCompletedResponse()) + reg = Registry() + @reg.activity(name="activity_type") def activity_fn(): try: asyncio.get_running_loop() @@ -94,9 +103,9 @@ def activity_fn(): return "success" raise RuntimeError("expected to be running outside of the event loop") - executor = ActivityExecutor(client, 'task_list', 'identity', 1, lambda name: activity_fn) + executor = ActivityExecutor(client, 'task_list', 'identity', 1, reg.get_activity) - await executor.execute(fake_task("any", "")) + await executor.execute(fake_task("activity_type", "")) worker_stub.RespondActivityTaskCompleted.assert_called_once_with(RespondActivityTaskCompletedRequest( task_token=b'task_token', @@ -107,13 +116,14 @@ def activity_fn(): async def test_activity_sync_failure(client): worker_stub = client.worker_stub worker_stub.RespondActivityTaskFailed = AsyncMock(return_value=RespondActivityTaskFailedResponse()) - + reg = Registry() + @reg.activity(name="activity_type") def activity_fn(): raise KeyError("failure") - executor = ActivityExecutor(client, 'task_list', 'identity', 1, lambda name: activity_fn) + executor = ActivityExecutor(client, 'task_list', 'identity', 1, reg.get_activity) - await executor.execute(fake_task("any", "")) + await executor.execute(fake_task("activity_type", "")) worker_stub.RespondActivityTaskFailed.assert_called_once() @@ -134,18 +144,18 @@ async def test_activity_unknown(client): worker_stub = client.worker_stub worker_stub.RespondActivityTaskFailed = AsyncMock(return_value=RespondActivityTaskFailedResponse()) - def registry(name: str): + def registry(name: str) -> ActivityDefinition: raise KeyError(f"unknown activity: {name}") executor = ActivityExecutor(client, 'task_list', 'identity', 1, registry) - await executor.execute(fake_task("any", "")) + await executor.execute(fake_task("activity_type", "")) worker_stub.RespondActivityTaskFailed.assert_called_once() call = worker_stub.RespondActivityTaskFailed.call_args[0][0] - assert 'Activity type not found: any' in call.failure.details.decode() + assert 'Activity type not found: activity_type' in call.failure.details.decode() call.failure.details = bytes() assert call == RespondActivityTaskFailedRequest( task_token=b'task_token', @@ -158,14 +168,15 @@ def registry(name: str): async def test_activity_context(client): worker_stub = client.worker_stub worker_stub.RespondActivityTaskCompleted = AsyncMock(return_value=RespondActivityTaskCompletedResponse()) - + reg = Registry() + @reg.activity(name="activity_type") async def activity_fn(): assert fake_info("activity_type") == activity.info() assert activity.in_activity() assert activity.client() is not None return "success" - executor = ActivityExecutor(client, 'task_list', 'identity', 1, lambda name: activity_fn) + executor = ActivityExecutor(client, 'task_list', 'identity', 1, reg.get_activity) await executor.execute(fake_task("activity_type", "")) @@ -179,6 +190,8 @@ async def test_activity_context_sync(client): worker_stub = client.worker_stub worker_stub.RespondActivityTaskCompleted = AsyncMock(return_value=RespondActivityTaskCompletedResponse()) + reg = Registry() + @reg.activity(name="activity_type") def activity_fn(): assert fake_info("activity_type") == activity.info() assert activity.in_activity() @@ -186,7 +199,7 @@ def activity_fn(): activity.client() return "success" - executor = ActivityExecutor(client, 'task_list', 'identity', 1, lambda name: activity_fn) + executor = ActivityExecutor(client, 'task_list', 'identity', 1, reg.get_activity) await executor.execute(fake_task("activity_type", "")) diff --git a/tests/cadence/_internal/test_type_utils.py b/tests/cadence/_internal/test_type_utils.py deleted file mode 100644 index 9e35e81..0000000 --- a/tests/cadence/_internal/test_type_utils.py +++ /dev/null @@ -1,70 +0,0 @@ -from typing import Callable, Type - -import pytest - -from cadence._internal.type_utils import get_fn_parameters, validate_fn_parameters - - -def _single_param(name: str): - ... - -def _multiple_param(name: str, other: 'str'): - ... - -def _with_args(name:str, *args): - ... - -def _with_kwargs(name:str, **kwargs): - ... - -def _strictly_positional(name: str, other: str, *args, **kwargs): - ... - -def _keyword_only(*args, foo: str): - ... - - -@pytest.mark.parametrize( - "fn,expected", - [ - pytest.param( - _single_param, [str], id="single param" - ), - pytest.param( - _multiple_param, [str, str], id="multiple param" - ), - pytest.param( - _strictly_positional, [str, str], id="strictly positional" - ), - pytest.param( - _keyword_only, [], id="keyword only" - ), - ] -) -def test_get_fn_parameters(fn: Callable, expected: list[Type]): - params = get_fn_parameters(fn) - assert params == expected - -@pytest.mark.parametrize( - "fn,expected", - [ - pytest.param( - _single_param, None, id="single param" - ), - pytest.param( - _multiple_param, None, id="multiple param" - ), - pytest.param( - _with_args, ValueError, id="with args" - ), - pytest.param( - _with_kwargs, ValueError, id="with kwargs" - ), - ] -) -def test_validate_fn_parameters(fn: Callable, expected: Type[Exception]): - if expected: - with pytest.raises(expected): - validate_fn_parameters(fn) - else: - validate_fn_parameters(fn) \ No newline at end of file diff --git a/tests/cadence/common_activities.py b/tests/cadence/common_activities.py new file mode 100644 index 0000000..be78c62 --- /dev/null +++ b/tests/cadence/common_activities.py @@ -0,0 +1,51 @@ +from dataclasses import dataclass + +from cadence import activity + + +@activity.defn() +def simple_fn() -> None: + pass + +@activity.defn +def no_parens() -> None: + pass + +@activity.defn() +def echo(incoming: str) -> str: + return incoming + +@activity.defn(name="renamed") +def renamed_fn() -> None: + pass + +@activity.defn() +async def async_fn() -> None: + pass + +class Activities: + + @activity.defn() + def echo_sync(self, incoming: str) -> str: + return incoming + + @activity.defn() + async def echo_async(self, incoming: str) -> str: + return incoming + +class ActivityInterface: + @activity.defn() + def do_something(self) -> str: + ... + +@dataclass +class ActivityImpl(ActivityInterface): + result: str + + def do_something(self) -> str: + return self.result + +class InvalidImpl(ActivityInterface): + @activity.defn(name="something else entirely") + def do_something(self) -> str: + return "hehe" \ No newline at end of file diff --git a/tests/cadence/worker/test_registry.py b/tests/cadence/worker/test_registry.py index 57f345b..4a8973b 100644 --- a/tests/cadence/worker/test_registry.py +++ b/tests/cadence/worker/test_registry.py @@ -5,7 +5,9 @@ import pytest -from cadence.worker import Registry, RegisterWorkflowOptions, RegisterActivityOptions +from cadence import activity +from cadence.worker import Registry +from tests.cadence import common_activities class TestRegistry: @@ -35,74 +37,22 @@ def test_func(): def test_func(): return "test" - func = reg.get_activity("test_func") + func = reg.get_activity(test_func.name) assert func() == "test" - @pytest.mark.parametrize("registration_type", ["workflow", "activity"]) - def test_direct_call_behavior(self, registration_type): - """Test direct function call behavior for both workflows and activities.""" + def test_direct_call_behavior(self): reg = Registry() - + + @activity.defn(name="test_func") def test_func(): return "direct_call" + + reg.register_activity(test_func) + func = reg.get_activity("test_func") - if registration_type == "workflow": - registered_func = reg.workflow(test_func) - func = reg.get_workflow("test_func") - else: - registered_func = reg.activity(test_func) - func = reg.get_activity("test_func") - - assert registered_func == test_func assert func() == "direct_call" - @pytest.mark.parametrize("registration_type", ["workflow", "activity"]) - def test_decorator_with_options(self, registration_type): - """Test decorator with options for both workflows and activities.""" - reg = Registry() - - if registration_type == "workflow": - @reg.workflow(name="custom_name", alias="custom_alias") - def test_func(): - return "decorator_with_options" - - func = reg.get_workflow("custom_name") - func_by_alias = reg.get_workflow("custom_alias") - else: - @reg.activity(name="custom_name", alias="custom_alias") - def test_func(): - return "decorator_with_options" - - func = reg.get_activity("custom_name") - func_by_alias = reg.get_activity("custom_alias") - - assert func() == "decorator_with_options" - assert func_by_alias() == "decorator_with_options" - assert func == func_by_alias - - @pytest.mark.parametrize("registration_type", ["workflow", "activity"]) - def test_direct_call_with_options(self, registration_type): - """Test direct call with options for both workflows and activities.""" - reg = Registry() - - def test_func(): - return "direct_call_with_options" - - if registration_type == "workflow": - registered_func = reg.workflow(test_func, name="custom_name", alias="custom_alias") - func = reg.get_workflow("custom_name") - func_by_alias = reg.get_workflow("custom_alias") - else: - registered_func = reg.activity(test_func, name="custom_name", alias="custom_alias") - func = reg.get_activity("custom_name") - func_by_alias = reg.get_activity("custom_name") - - assert registered_func == test_func - assert func() == "direct_call_with_options" - assert func_by_alias() == "direct_call_with_options" - assert func == func_by_alias - @pytest.mark.parametrize("registration_type", ["workflow", "activity"]) def test_not_found_error(self, registration_type): """Test KeyError is raised when function not found.""" @@ -130,62 +80,73 @@ def test_func(): def test_func(): return "duplicate" else: - @reg.activity + @reg.activity(name="test_func") def test_func(): return "test" - + with pytest.raises(KeyError): - @reg.activity + @reg.activity(name="test_func") def test_func(): return "duplicate" - - @pytest.mark.parametrize("registration_type", ["workflow", "activity"]) - def test_alias_functionality(self, registration_type): - """Test alias functionality for both workflows and activities.""" + + def test_register_activities_instance(self): reg = Registry() - - if registration_type == "workflow": - @reg.workflow(name="custom_name") - def test_func(): - return "test" - - func = reg.get_workflow("custom_name") - else: - @reg.activity(alias="custom_alias") - def test_func(): - return "test" - - func = reg.get_activity("custom_alias") - func_by_name = reg.get_activity("test_func") - assert func_by_name() == "test" - assert func == func_by_name - - assert func() == "test" - - @pytest.mark.parametrize("registration_type", ["workflow", "activity"]) - def test_options_class(self, registration_type): - """Test using options classes for both workflows and activities.""" + + reg.register_activities(common_activities.Activities()) + + assert reg.get_activity("Activities.echo_sync") is not None + assert reg.get_activity("Activities.echo_sync") is not None + + def test_register_activities_interface(self): + impl = common_activities.ActivityImpl("result") reg = Registry() - - if registration_type == "workflow": - options = RegisterWorkflowOptions(name="custom_name", alias="custom_alias") - - @reg.workflow(**options) - def test_func(): - return "test" - - func = reg.get_workflow("custom_name") - func_by_alias = reg.get_workflow("custom_alias") - else: - options = RegisterActivityOptions(name="custom_name", alias="custom_alias") - - @reg.activity(**options) - def test_func(): - return "test" - - func = reg.get_activity("custom_name") - func_by_alias = reg.get_activity("custom_alias") - - assert func() == "test" - assert func_by_alias() == "test" - assert func == func_by_alias + + reg.register_activities(impl) + + assert reg.get_activity(common_activities.ActivityInterface.do_something.name) is not None + assert reg.get_activity("ActivityInterface.do_something") is not None + assert reg.get_activity(common_activities.ActivityInterface.do_something.name)() == "result" + + def test_register_activities_invalid_impl(self): + impl = common_activities.InvalidImpl() + reg = Registry() + + with pytest.raises(ValueError): + reg.register_activities(impl) + + + def test_add(self): + registry = Registry() + registry.register_activity(common_activities.simple_fn) + other = Registry() + other.register_activity(common_activities.echo) + + result = registry + other + + assert result.get_activity("simple_fn") is not None + assert result.get_activity("echo") is not None + with pytest.raises(KeyError): + registry.get_activity("echo") + with pytest.raises(KeyError): + other.get_activity("simple_fn") + + def test_add_duplicate(self): + registry = Registry() + registry.register_activity(common_activities.simple_fn) + other = Registry() + other.register_activity(common_activities.simple_fn) + with pytest.raises(KeyError): + registry + other + + def test_of(self): + first = Registry() + second = Registry() + third = Registry() + first.register_activity(common_activities.simple_fn) + second.register_activity(common_activities.echo) + third.register_activity(common_activities.async_fn) + + result = Registry.of(first, second, third) + assert result.get_activity("simple_fn") is not None + assert result.get_activity("echo") is not None + assert result.get_activity("async_fn") is not None