diff --git a/examples/workflow-async/README.md b/examples/workflow-async/README.md new file mode 100644 index 00000000..4ce670df --- /dev/null +++ b/examples/workflow-async/README.md @@ -0,0 +1,15 @@ +# Dapr Workflow Async Examples (Python) + +These examples mirror `examples/workflow/` but author orchestrators with `async def` using the +async workflow APIs. Activities remain regular functions unless noted. + +How to run: +- Ensure a Dapr sidecar is running locally. If needed, set `DURABLETASK_GRPC_ENDPOINT`, or + `DURABLETASK_GRPC_HOST/PORT`. +- Install requirements: `pip install -r requirements.txt` +- Run any example: `python simple.py` + +Notes: +- Orchestrators use `await ctx.activity(...)`, `await ctx.sleep(...)`, `await ctx.when_all/when_any(...)`, etc. +- No event loop is started manually; the Durable Task worker drives the async orchestrators. +- You can also launch instances using `DaprWorkflowClient` as in the non-async examples. diff --git a/examples/workflow-async/child_workflow.py b/examples/workflow-async/child_workflow.py new file mode 100644 index 00000000..9e89a752 --- /dev/null +++ b/examples/workflow-async/child_workflow.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the specific language governing permissions and +limitations under the License. +""" + +from dapr.ext.workflow import ( + AsyncWorkflowContext, + DaprWorkflowClient, + WorkflowRuntime, +) + +wfr = WorkflowRuntime() + + +@wfr.async_workflow(name='child_async') +async def child(ctx: AsyncWorkflowContext, n: int) -> int: + return n * 2 + + +@wfr.async_workflow(name='parent_async') +async def parent(ctx: AsyncWorkflowContext, n: int) -> int: + r = await ctx.call_child_workflow(child, input=n) + print(f'Child workflow returned {r}') + return r + 1 + + +def main(): + wfr.start() + client = DaprWorkflowClient() + instance_id = 'parent_async_instance' + client.schedule_new_workflow(workflow=parent, input=5, instance_id=instance_id) + client.wait_for_workflow_completion(instance_id, timeout_in_seconds=60) + wfr.shutdown() + + +if __name__ == '__main__': + main() diff --git a/examples/workflow-async/fan_out_fan_in.py b/examples/workflow-async/fan_out_fan_in.py new file mode 100644 index 00000000..16e6e48d --- /dev/null +++ b/examples/workflow-async/fan_out_fan_in.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the specific language governing permissions and +limitations under the License. 
+""" + +from dapr.ext.workflow import ( + AsyncWorkflowContext, + DaprWorkflowClient, + WorkflowActivityContext, + WorkflowRuntime, +) + +wfr = WorkflowRuntime() + + +@wfr.activity(name='square') +def square(ctx: WorkflowActivityContext, x: int) -> int: + return x * x + + +@wfr.async_workflow(name='fan_out_fan_in_async') +async def orchestrator(ctx: AsyncWorkflowContext): + tasks = [ctx.call_activity(square, input=i) for i in range(1, 6)] + results = await ctx.when_all(tasks) + total = sum(results) + return total + + +def main(): + wfr.start() + client = DaprWorkflowClient() + instance_id = 'fofi_async' + client.schedule_new_workflow(workflow=orchestrator, instance_id=instance_id) + wf_state = client.wait_for_workflow_completion(instance_id, timeout_in_seconds=60) + print(f'Workflow state: {wf_state}') + wfr.shutdown() + + +if __name__ == '__main__': + main() diff --git a/examples/workflow-async/human_approval.py b/examples/workflow-async/human_approval.py new file mode 100644 index 00000000..7ce17722 --- /dev/null +++ b/examples/workflow-async/human_approval.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the specific language governing permissions and +limitations under the License. +""" + +from dapr.ext.workflow import AsyncWorkflowContext, DaprWorkflowClient, WorkflowRuntime + +wfr = WorkflowRuntime() + + +@wfr.async_workflow(name='human_approval_async') +async def orchestrator(ctx: AsyncWorkflowContext, request_id: str): + decision = await ctx.when_any( + [ + ctx.wait_for_external_event(f'approve:{request_id}'), + ctx.wait_for_external_event(f'reject:{request_id}'), + ctx.create_timer(300.0), + ] + ) + if isinstance(decision, dict) and decision.get('approved'): + return 'APPROVED' + if isinstance(decision, dict) and decision.get('rejected'): + return 'REJECTED' + return 'TIMEOUT' + + +def main(): + wfr.start() + client = DaprWorkflowClient() + instance_id = 'human_approval_async_1' + client.schedule_new_workflow(workflow=orchestrator, input='REQ-1', instance_id=instance_id) + # In a real scenario, raise approve/reject event from another service. + wfr.shutdown() + + +if __name__ == '__main__': + main() diff --git a/examples/workflow-async/requirements.txt b/examples/workflow-async/requirements.txt new file mode 100644 index 00000000..e220036d --- /dev/null +++ b/examples/workflow-async/requirements.txt @@ -0,0 +1,2 @@ +dapr-ext-workflow-dev>=1.15.0.dev +dapr-dev>=1.15.0.dev diff --git a/examples/workflow-async/simple.py b/examples/workflow-async/simple.py new file mode 100644 index 00000000..1e7cf039 --- /dev/null +++ b/examples/workflow-async/simple.py @@ -0,0 +1,136 @@ +# -*- coding: utf-8 -*- +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the specific language governing permissions and +limitations under the License. +""" + +from datetime import timedelta +from time import sleep + +from dapr.ext.workflow import ( + AsyncWorkflowContext, + DaprWorkflowClient, + RetryPolicy, + WorkflowActivityContext, + WorkflowRuntime, +) + +counter = 0 +retry_count = 0 +child_orchestrator_string = '' +instance_id = 'asyncExampleInstanceID' +child_instance_id = 'asyncChildInstanceID' +workflow_name = 'async_hello_world_wf' +child_workflow_name = 'async_child_wf' +input_data = 'Hi Async Counter!' +event_name = 'event1' +event_data = 'eventData' + +retry_policy = RetryPolicy( + first_retry_interval=timedelta(seconds=1), + max_number_of_attempts=3, + backoff_coefficient=2, + max_retry_interval=timedelta(seconds=10), + retry_timeout=timedelta(seconds=100), +) + +wfr = WorkflowRuntime() + + +@wfr.async_workflow(name=workflow_name) +async def hello_world_wf(ctx: AsyncWorkflowContext, wf_input): + # activities + result_1 = await ctx.call_activity(hello_act, input=1) + print(f'Activity 1 returned {result_1}') + result_2 = await ctx.call_activity(hello_act, input=10) + print(f'Activity 2 returned {result_2}') + result_3 = await ctx.call_activity(hello_retryable_act, retry_policy=retry_policy) + print(f'Activity 3 returned {result_3}') + result_4 = await ctx.call_child_workflow(child_retryable_wf, retry_policy=retry_policy) + print(f'Child workflow returned {result_4}') + + # Event vs timeout using when_any + first = await ctx.when_any( + [ + ctx.wait_for_external_event(event_name), + ctx.create_timer(timedelta(seconds=30)), + ] + ) + + # Proceed only if event won + if isinstance(first, dict) and 'event' in first: + await ctx.call_activity(hello_act, input=100) + await ctx.call_activity(hello_act, input=1000) + return 'Completed' + return 'Timeout' + + +@wfr.activity(name='async_hello_act') +def hello_act(ctx: WorkflowActivityContext, wf_input): + global counter + counter += wf_input + return f'Activity returned {wf_input}' + + +@wfr.activity(name='async_hello_retryable_act') +def hello_retryable_act(ctx: WorkflowActivityContext): + global retry_count + if (retry_count % 2) == 0: + retry_count += 1 + raise ValueError('Retryable Error') + retry_count += 1 + return f'Activity returned {retry_count}' + + +@wfr.async_workflow(name=child_workflow_name) +async def child_retryable_wf(ctx: AsyncWorkflowContext): + # Call activity with retry and simulate retryable workflow failure until certain state + child_activity_result = await ctx.call_activity( + act_for_child_wf, input='x', retry_policy=retry_policy + ) + print(f'Child activity returned {child_activity_result}') + # In a real sample, you might check state and raise to trigger retry + return 'ok' + + +@wfr.activity(name='async_act_for_child_wf') +def act_for_child_wf(ctx: WorkflowActivityContext, inp): + global child_orchestrator_string + child_orchestrator_string += inp + + +def main(): + wfr.start() + wf_client = DaprWorkflowClient() + + wf_client.schedule_new_workflow( + workflow=hello_world_wf, input=input_data, instance_id=instance_id + ) + + wf_client.wait_for_workflow_start(instance_id) + + # Let initial activities run + sleep(5) + + # Raise event to continue + wf_client.raise_workflow_event( + 
instance_id=instance_id, event_name=event_name, data={'ok': True} + ) + + # Wait for completion + state = wf_client.wait_for_workflow_completion(instance_id, timeout_in_seconds=60) + print(f'Workflow status: {state.runtime_status.name}') + + wfr.shutdown() + + +if __name__ == '__main__': + main() diff --git a/examples/workflow-async/task_chaining.py b/examples/workflow-async/task_chaining.py new file mode 100644 index 00000000..c9c92add --- /dev/null +++ b/examples/workflow-async/task_chaining.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the specific language governing permissions and +limitations under the License. +""" + +from dapr.ext.workflow import ( + AsyncWorkflowContext, + DaprWorkflowClient, + WorkflowActivityContext, + WorkflowRuntime, +) + +wfr = WorkflowRuntime() + + +@wfr.activity(name='sum') +def sum_act(ctx: WorkflowActivityContext, nums): + return sum(nums) + + +@wfr.async_workflow(name='task_chaining_async') +async def orchestrator(ctx: AsyncWorkflowContext): + a = await ctx.call_activity(sum_act, input=[1, 2]) + b = await ctx.call_activity(sum_act, input=[a, 3]) + c = await ctx.call_activity(sum_act, input=[b, 4]) + return c + + +def main(): + wfr.start() + client = DaprWorkflowClient() + instance_id = 'task_chain_async' + client.schedule_new_workflow(workflow=orchestrator, instance_id=instance_id) + client.wait_for_workflow_completion(instance_id, timeout_in_seconds=60) + wfr.shutdown() + + +if __name__ == '__main__': + main() diff --git a/examples/workflow/README.md b/examples/workflow/README.md index 2e09eeef..3cb71f0d 100644 --- a/examples/workflow/README.md +++ b/examples/workflow/README.md @@ -12,6 +12,8 @@ This directory contains examples of using the [Dapr Workflow](https://docs.dapr. You can install dapr SDK package using pip command: ```sh +python3 -m venv .venv +source .venv/bin/activate pip3 install -r requirements.txt ``` diff --git a/examples/workflow/aio/async_activity_sequence.py b/examples/workflow/aio/async_activity_sequence.py new file mode 100644 index 00000000..8eecd1f8 --- /dev/null +++ b/examples/workflow/aio/async_activity_sequence.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the specific language governing permissions and +limitations under the License. 
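The examples above wait for completion and print the returned state; a small sketch of checking it defensively before use, assuming `state.runtime_status` is the `WorkflowStatus` enum exported by `dapr.ext.workflow`:

```python
from dapr.ext.workflow import DaprWorkflowClient, WorkflowStatus

client = DaprWorkflowClient()
state = client.wait_for_workflow_completion('task_chain_async', timeout_in_seconds=60)
if state is None:
    raise RuntimeError('no state returned (instance unknown or wait timed out)')
if state.runtime_status != WorkflowStatus.COMPLETED:
    raise RuntimeError(f'workflow ended in {state.runtime_status.name}')
print('task_chain_async completed')
```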
+""" + +from dapr.ext.workflow import AsyncWorkflowContext, WorkflowRuntime + + +def main(): + rt = WorkflowRuntime() + + @rt.activity(name='add') + def add(ctx, xy): + return xy[0] + xy[1] + + @rt.workflow(name='sum_three') + async def sum_three(ctx: AsyncWorkflowContext, nums): + a = await ctx.call_activity(add, input=[nums[0], nums[1]]) + b = await ctx.call_activity(add, input=[a, nums[2]]) + return b + + rt.start() + print("Registered async workflow 'sum_three' and activity 'add'") + + # This example registers only; use Dapr client to start instances externally. + + +if __name__ == '__main__': + main() diff --git a/examples/workflow/aio/async_external_event.py b/examples/workflow/aio/async_external_event.py new file mode 100644 index 00000000..90531422 --- /dev/null +++ b/examples/workflow/aio/async_external_event.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the specific language governing permissions and +limitations under the License. +""" + +from dapr.ext.workflow import AsyncWorkflowContext, WorkflowRuntime + + +def main(): + rt = WorkflowRuntime() + + @rt.async_workflow(name='wait_event') + async def wait_event(ctx: AsyncWorkflowContext): + data = await ctx.wait_for_external_event('go') + return {'event': data} + + rt.start() + print("Registered async workflow 'wait_event'") + + +if __name__ == '__main__': + main() diff --git a/examples/workflow/aio/async_sub_orchestrator.py b/examples/workflow/aio/async_sub_orchestrator.py new file mode 100644 index 00000000..c00d9ca9 --- /dev/null +++ b/examples/workflow/aio/async_sub_orchestrator.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the specific language governing permissions and +limitations under the License. +""" + +from dapr.ext.workflow import AsyncWorkflowContext, WorkflowRuntime + + +def main(): + rt = WorkflowRuntime() + + @rt.async_workflow(name='child') + async def child(ctx: AsyncWorkflowContext, n): + return n * 2 + + @rt.async_workflow(name='parent') + async def parent(ctx: AsyncWorkflowContext, n): + r = await ctx.call_child_workflow(child, input=n) + return r + 1 + + rt.start() + print("Registered async workflows 'parent' and 'child'") + + +if __name__ == '__main__': + main() diff --git a/examples/workflow/aio/context_interceptors_example.py b/examples/workflow/aio/context_interceptors_example.py new file mode 100644 index 00000000..d005bca1 --- /dev/null +++ b/examples/workflow/aio/context_interceptors_example.py @@ -0,0 +1,152 @@ +# -*- coding: utf-8 -*- + +""" +Example: Interceptors for context propagation (client + runtime). 
+ +This example shows how to: + - Define a small context (dict) carried via contextvars + - Implement ClientInterceptor to inject that context into outbound inputs + - Implement RuntimeInterceptor to restore the context before user code runs + - Wire interceptors into WorkflowRuntime and DaprWorkflowClient + +Note: Scheduling/running requires a Dapr sidecar. This file focuses on the wiring pattern. +""" + +from __future__ import annotations + +import contextvars +from typing import Any, Callable + +from dapr.ext.workflow import ( + BaseClientInterceptor, + BaseRuntimeInterceptor, + BaseWorkflowOutboundInterceptor, + CallActivityInput, + CallChildWorkflowInput, + DaprWorkflowClient, + ExecuteActivityInput, + ExecuteWorkflowInput, + ScheduleWorkflowInput, + WorkflowRuntime, +) + +# A simple context carried across boundaries +_current_ctx: contextvars.ContextVar[dict[str, Any] | None] = contextvars.ContextVar( + 'wf_ctx', default=None +) + + +def set_ctx(ctx: dict[str, Any] | None) -> None: + _current_ctx.set(ctx) + + +def get_ctx() -> dict[str, Any] | None: + return _current_ctx.get() + + +def _merge_ctx(args: Any) -> Any: + ctx = get_ctx() + if ctx and isinstance(args, dict) and 'context' not in args: + return {**args, 'context': ctx} + return args + + +class ContextClientInterceptor(BaseClientInterceptor): + def schedule_new_workflow( + self, input: ScheduleWorkflowInput, nxt: Callable[[ScheduleWorkflowInput], Any] + ) -> Any: # type: ignore[override] + input = ScheduleWorkflowInput( + workflow_name=input.workflow_name, + args=_merge_ctx(input.args), + instance_id=input.instance_id, + start_at=input.start_at, + reuse_id_policy=input.reuse_id_policy, + ) + return nxt(input) + + +class ContextWorkflowOutboundInterceptor(BaseWorkflowOutboundInterceptor): + def call_child_workflow( + self, input: CallChildWorkflowInput, nxt: Callable[[CallChildWorkflowInput], Any] + ) -> Any: + return nxt( + CallChildWorkflowInput( + workflow_name=input.workflow_name, + args=_merge_ctx(input.args), + instance_id=input.instance_id, + workflow_ctx=input.workflow_ctx, + metadata=input.metadata, + local_context=input.local_context, + ) + ) + + def call_activity( + self, input: CallActivityInput, nxt: Callable[[CallActivityInput], Any] + ) -> Any: + return nxt( + CallActivityInput( + activity_name=input.activity_name, + args=_merge_ctx(input.args), + retry_policy=input.retry_policy, + workflow_ctx=input.workflow_ctx, + metadata=input.metadata, + local_context=input.local_context, + ) + ) + + +class ContextRuntimeInterceptor(BaseRuntimeInterceptor): + def execute_workflow( + self, input: ExecuteWorkflowInput, nxt: Callable[[ExecuteWorkflowInput], Any] + ) -> Any: # type: ignore[override] + if isinstance(input.input, dict) and 'context' in input.input: + set_ctx(input.input['context']) + try: + return nxt(input) + finally: + set_ctx(None) + + def execute_activity( + self, input: ExecuteActivityInput, nxt: Callable[[ExecuteActivityInput], Any] + ) -> Any: # type: ignore[override] + if isinstance(input.input, dict) and 'context' in input.input: + set_ctx(input.input['context']) + try: + return nxt(input) + finally: + set_ctx(None) + + +# Example workflow and activity +def activity_log(ctx, data: dict[str, Any]) -> str: # noqa: ANN001 (example) + # Access restored context inside activity via contextvars + return f'ok:{get_ctx()}' + + +def workflow_example(ctx, x: int): # noqa: ANN001 (example) + y = yield ctx.call_activity(activity_log, input={'msg': 'hello'}) + return y + + +def wire_up() -> tuple[WorkflowRuntime, 
DaprWorkflowClient]: + runtime = WorkflowRuntime( + runtime_interceptors=[ContextRuntimeInterceptor()], + workflow_outbound_interceptors=[ContextWorkflowOutboundInterceptor()], + ) + client = DaprWorkflowClient(interceptors=[ContextClientInterceptor()]) + + # Register workflow/activity + runtime.workflow(name='example')(workflow_example) + runtime.activity(name='activity_log')(activity_log) + return runtime, client + + +if __name__ == '__main__': + # This section demonstrates how you would set a context and schedule a workflow. + # Requires a running Dapr sidecar to actually execute. + rt, cli = wire_up() + set_ctx({'tenant': 'acme', 'request_id': 'r-123'}) + # instance_id = cli.schedule_new_workflow(workflow_example, input={'x': 1}) + # print('scheduled:', instance_id) + # rt.start(); rt.wait_for_ready(); ... + pass diff --git a/examples/workflow/aio/model_tool_serialization_example.py b/examples/workflow/aio/model_tool_serialization_example.py new file mode 100644 index 00000000..2c1bdf4c --- /dev/null +++ b/examples/workflow/aio/model_tool_serialization_example.py @@ -0,0 +1,66 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +from typing import Any, Dict + +from dapr.ext.workflow import ensure_canonical_json + +""" +Example of implementing provider-specific model/tool serialization OUTSIDE the core package. + +This demonstrates how to build and use your own contracts using the generic helpers from +`dapr.ext.workflow.serializers`. 
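A usage sketch for the helpers defined just below, showing a model request/response passing through the canonical-JSON guard (the field names follow this example's own contracts; the call site is hypothetical):

```python
# Hypothetical call site, e.g. an activity that talks to a model provider.
request = to_model_request({
    'model_name': 'example-model',
    'system_instructions': 'You are terse.',
    'input': 'Summarize the release notes.',
    'model_settings': {'temperature': 0.2},
    'tools': [],
})
# request is a plain JSON-safe dict tagged with schema_version 'model_req@v1'.

provider_output = {'content': 'Done.', 'tool_calls': []}
response = from_model_response(provider_output)
assert response['schema_version'] == 'model_res@v1'
```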
+""" + + +def to_model_request(payload: Dict[str, Any]) -> Dict[str, Any]: + req = { + 'schema_version': 'model_req@v1', + 'model_name': payload.get('model_name'), + 'system_instructions': payload.get('system_instructions'), + 'input': payload.get('input'), + 'model_settings': payload.get('model_settings') or {}, + 'tools': payload.get('tools') or [], + } + return ensure_canonical_json(req, strict=True) + + +def from_model_response(obj: Any) -> Dict[str, Any]: + if isinstance(obj, dict): + content = obj.get('content') + tool_calls = obj.get('tool_calls') or [] + out = {'schema_version': 'model_res@v1', 'content': content, 'tool_calls': tool_calls} + return ensure_canonical_json(out, strict=False) + return ensure_canonical_json( + {'schema_version': 'model_res@v1', 'content': str(obj), 'tool_calls': []}, strict=False + ) + + +def to_tool_request(name: str, args: list | None, kwargs: dict | None) -> Dict[str, Any]: + req = { + 'schema_version': 'tool_req@v1', + 'tool_name': name, + 'args': args or [], + 'kwargs': kwargs or {}, + } + return ensure_canonical_json(req, strict=True) + + +def from_tool_result(obj: Any) -> Dict[str, Any]: + if isinstance(obj, dict) and ('result' in obj or 'error' in obj): + return ensure_canonical_json({'schema_version': 'tool_res@v1', **obj}, strict=False) + return ensure_canonical_json( + {'schema_version': 'tool_res@v1', 'result': obj, 'error': None}, strict=False + ) diff --git a/examples/workflow/aio/tracing_interceptors_example.py b/examples/workflow/aio/tracing_interceptors_example.py new file mode 100644 index 00000000..ea4834fb --- /dev/null +++ b/examples/workflow/aio/tracing_interceptors_example.py @@ -0,0 +1,120 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +from typing import Any, Callable + +from dapr.ext.workflow import ( + BaseClientInterceptor, + BaseRuntimeInterceptor, + BaseWorkflowOutboundInterceptor, + CallActivityInput, + CallChildWorkflowInput, + DaprWorkflowClient, + ExecuteActivityInput, + ExecuteWorkflowInput, + ScheduleWorkflowInput, + WorkflowRuntime, +) + +TRACE_ID_KEY = 'otel.trace_id' +SPAN_ID_KEY = 'otel.span_id' + + +class TracingClientInterceptor(BaseClientInterceptor): + def __init__(self, get_current_trace: Callable[[], tuple[str, str]]): + self._get = get_current_trace + + def schedule_new_workflow(self, input: ScheduleWorkflowInput, next): # type: ignore[override] + trace_id, span_id = self._get() + md = dict(input.metadata or {}) + md[TRACE_ID_KEY] = trace_id + md[SPAN_ID_KEY] = span_id + return next( + ScheduleWorkflowInput( + workflow_name=input.workflow_name, + args=input.args, + instance_id=input.instance_id, + start_at=input.start_at, + reuse_id_policy=input.reuse_id_policy, + metadata=md, + local_context=input.local_context, + ) + ) + + +class TracingRuntimeInterceptor(BaseRuntimeInterceptor): + def __init__(self, on_span: Callable[[str, dict[str, str]], Any]): + self._on_span = on_span + + def execute_workflow(self, input: ExecuteWorkflowInput, next): # type: ignore[override] + # Suppress spans during replay + if not input.ctx.is_replaying: + self._on_span('dapr:executeWorkflow', input.metadata or {}) + return next(input) + + def execute_activity(self, input: ExecuteActivityInput, next): # type: ignore[override] + self._on_span('dapr:executeActivity', input.metadata or {}) + return next(input) + + +class TracingWorkflowOutboundInterceptor(BaseWorkflowOutboundInterceptor): + def __init__(self, get_current_trace: Callable[[], tuple[str, str]]): + self._get = get_current_trace + + def 
call_activity(self, input: CallActivityInput, next): # type: ignore[override] + trace_id, span_id = self._get() + md = dict((input.metadata or {}) or {}) + md[TRACE_ID_KEY] = md.get(TRACE_ID_KEY, trace_id) + md[SPAN_ID_KEY] = span_id + return next( + type(input)( + activity_name=input.activity_name, + args=input.args, + retry_policy=input.retry_policy, + workflow_ctx=input.workflow_ctx, + metadata=md, + local_context=input.local_context, + ) + ) + + def call_child_workflow(self, input: CallChildWorkflowInput, next): # type: ignore[override] + trace_id, span_id = self._get() + md = dict((input.metadata or {}) or {}) + md[TRACE_ID_KEY] = md.get(TRACE_ID_KEY, trace_id) + md[SPAN_ID_KEY] = span_id + return next( + type(input)( + workflow_name=input.workflow_name, + args=input.args, + instance_id=input.instance_id, + workflow_ctx=input.workflow_ctx, + metadata=md, + local_context=input.local_context, + ) + ) + + +def example_usage(): + # Simplified trace getter and span recorder + def _get_trace(): + return ('trace-123', 'span-abc') + + spans: list[tuple[str, dict[str, str]]] = [] + + def _on_span(name: str, attrs: dict[str, str]): + spans.append((name, attrs)) + + runtime = WorkflowRuntime( + runtime_interceptors=[TracingRuntimeInterceptor(_on_span)], + workflow_outbound_interceptors=[TracingWorkflowOutboundInterceptor(_get_trace)], + ) + + client = DaprWorkflowClient(interceptors=[TracingClientInterceptor(_get_trace)]) + + # Register and run as you would normally; spans list can be asserted in tests + return runtime, client, spans + + +if __name__ == '__main__': # pragma: no cover + example_usage() diff --git a/examples/workflow/e2e_execinfo.py b/examples/workflow/e2e_execinfo.py new file mode 100644 index 00000000..91f0b295 --- /dev/null +++ b/examples/workflow/e2e_execinfo.py @@ -0,0 +1,58 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import time + +from dapr.ext.workflow import DaprWorkflowClient, WorkflowRuntime + + +def main(): + port = '50001' + + rt = WorkflowRuntime(port=port) + + def activity_noop(ctx): + ei = ctx.execution_info + # Return attempt (may be None if engine doesn't set it) + return { + 'attempt': ei.attempt if ei else None, + 'workflow_id': ei.workflow_id if ei else None, + } + + @rt.workflow(name='child-to-parent') + def child(ctx, x): + ei = ctx.execution_info + out = yield ctx.call_activity(activity_noop, input=None) + return { + 'child_workflow_name': ei.workflow_name if ei else None, + 'parent_instance_id': ei.parent_instance_id if ei else None, + 'activity': out, + } + + @rt.workflow(name='parent') + def parent(ctx, x): + res = yield ctx.call_child_workflow(child, input={'x': x}) + return res + + rt.register_activity(activity_noop, name='activity_noop') + + rt.start() + try: + # Wait for the worker to be ready to accept work + rt.wait_for_ready(timeout=10) + + client = DaprWorkflowClient(port=port) + instance_id = client.schedule_new_workflow(parent, input=1) + state = client.wait_for_workflow_completion(instance_id, timeout_in_seconds=30) + print('instance:', instance_id) + print('runtime_status:', state.runtime_status if state else None) + print('state:', state) + finally: + # Give a moment for logs to flush then shutdown + time.sleep(0.5) + rt.shutdown() + + +if __name__ == '__main__': + main() diff --git a/examples/workflow/requirements.txt b/examples/workflow/requirements.txt index c5af70b9..367e80be 100644 --- a/examples/workflow/requirements.txt +++ b/examples/workflow/requirements.txt @@ -1,2 +1,12 @@ -dapr-ext-workflow>=1.16.0.dev 
-dapr>=1.16.0.dev +# dapr-ext-workflow-dev>=1.16.0.dev +# dapr-dev>=1.16.0.dev + +# local development: install local packages in editable mode + +# if using dev version of durabletask-python +-e ../../../durabletask-python + +# if using dev version of dapr-ext-workflow +-e ../../ext/dapr-ext-workflow +-e ../.. + diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/__init__.py b/ext/dapr-ext-workflow/dapr/ext/workflow/__init__.py index dd2d45b7..b6a75e47 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/__init__.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/__init__.py @@ -14,9 +14,39 @@ """ # Import your main classes here +from dapr.ext.workflow.aio import AsyncWorkflowContext from dapr.ext.workflow.dapr_workflow_client import DaprWorkflowClient from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext, when_all, when_any +from dapr.ext.workflow.execution_info import ActivityExecutionInfo, WorkflowExecutionInfo +from dapr.ext.workflow.interceptors import ( + BaseClientInterceptor, + BaseRuntimeInterceptor, + BaseWorkflowOutboundInterceptor, + CallActivityRequest, + CallChildWorkflowRequest, + ClientInterceptor, + ExecuteActivityRequest, + ExecuteWorkflowRequest, + RuntimeInterceptor, + ScheduleWorkflowRequest, + WorkflowOutboundInterceptor, + compose_runtime_chain, + compose_workflow_outbound_chain, +) from dapr.ext.workflow.retry_policy import RetryPolicy +from dapr.ext.workflow.serializers import ( + ActivityIOAdapter, + CanonicalSerializable, + GenericSerializer, + ensure_canonical_json, + get_activity_adapter, + get_serializer, + register_activity_adapter, + register_serializer, + serialize_activity_input, + serialize_activity_output, + use_activity_adapter, +) from dapr.ext.workflow.workflow_activity_context import WorkflowActivityContext from dapr.ext.workflow.workflow_runtime import WorkflowRuntime, alternate_name from dapr.ext.workflow.workflow_state import WorkflowState, WorkflowStatus @@ -25,6 +55,7 @@ 'WorkflowRuntime', 'DaprWorkflowClient', 'DaprWorkflowContext', + 'AsyncWorkflowContext', 'WorkflowActivityContext', 'WorkflowState', 'WorkflowStatus', @@ -32,4 +63,32 @@ 'when_any', 'alternate_name', 'RetryPolicy', + # interceptors + 'ClientInterceptor', + 'BaseClientInterceptor', + 'WorkflowOutboundInterceptor', + 'BaseWorkflowOutboundInterceptor', + 'RuntimeInterceptor', + 'BaseRuntimeInterceptor', + 'ScheduleWorkflowRequest', + 'CallChildWorkflowRequest', + 'CallActivityRequest', + 'ExecuteWorkflowRequest', + 'ExecuteActivityRequest', + 'compose_workflow_outbound_chain', + 'compose_runtime_chain', + 'WorkflowExecutionInfo', + 'ActivityExecutionInfo', + # serializers + 'CanonicalSerializable', + 'GenericSerializer', + 'ActivityIOAdapter', + 'ensure_canonical_json', + 'register_serializer', + 'get_serializer', + 'register_activity_adapter', + 'get_activity_adapter', + 'use_activity_adapter', + 'serialize_activity_input', + 'serialize_activity_output', ] diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/aio/__init__.py b/ext/dapr-ext-workflow/dapr/ext/workflow/aio/__init__.py new file mode 100644 index 00000000..a195bc0a --- /dev/null +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/aio/__init__.py @@ -0,0 +1,43 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +# Note: Do not import WorkflowRuntime here to avoid circular imports +# Re-export async context and awaitables +from .async_context import AsyncWorkflowContext # noqa: F401 +from .async_driver import CoroutineOrchestratorRunner # noqa: F401 +from .awaitables import ( # noqa: F401 + ActivityAwaitable, + SleepAwaitable, + SubOrchestratorAwaitable, + WhenAllAwaitable, + WhenAnyAwaitable, +) + +""" +Async I/O surface for Dapr Workflow extension. + +This package provides explicit async-focused imports that mirror the top-level +exports, improving discoverability and aligning with dapr.aio patterns. +""" + +__all__ = [ + 'AsyncWorkflowContext', + 'CoroutineOrchestratorRunner', + 'ActivityAwaitable', + 'SubOrchestratorAwaitable', + 'SleepAwaitable', + 'WhenAllAwaitable', + 'WhenAnyAwaitable', +] diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/aio/async_context.py b/ext/dapr-ext-workflow/dapr/ext/workflow/aio/async_context.py new file mode 100644 index 00000000..218a1eb3 --- /dev/null +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/aio/async_context.py @@ -0,0 +1,167 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +from datetime import datetime, timedelta +from typing import Any, Awaitable, Callable, Sequence + +from durabletask import task +from durabletask.aio.awaitables import gather as _dt_gather # type: ignore[import-not-found] +from durabletask.deterministic import ( # type: ignore[F401] + DeterministicContextMixin, +) + +from .awaitables import ( + ActivityAwaitable, + ExternalEventAwaitable, + SleepAwaitable, + SubOrchestratorAwaitable, + WhenAllAwaitable, + WhenAnyAwaitable, +) + +""" +Async workflow context that exposes deterministic awaitables for activities, timers, +external events, and concurrency, along with deterministic utilities. 
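A compact sketch of the surface this context exposes, assuming activities `fetch` and `store` are registered elsewhere:

```python
from datetime import timedelta

from dapr.ext.workflow import AsyncWorkflowContext


async def example(ctx: AsyncWorkflowContext, urls: list[str]):
    # Fan out activity calls and join them deterministically.
    pages = await ctx.when_all([ctx.call_activity(fetch, input=u) for u in urls])
    # Durable timer instead of asyncio.sleep.
    await ctx.sleep(timedelta(seconds=5))
    # Race an external event against a 30-second timeout.
    winner = await ctx.when_any([ctx.wait_for_external_event('go'), ctx.create_timer(30.0)])
    return await ctx.call_activity(store, input={'pages': pages, 'signal': winner})
```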
+""" + + +class AsyncWorkflowContext(DeterministicContextMixin): + def __init__(self, base_ctx: task.OrchestrationContext): + self._base_ctx = base_ctx + + # Core workflow metadata parity with sync context + @property + def instance_id(self) -> str: + return self._base_ctx.instance_id + + @property + def current_utc_datetime(self) -> datetime: + return self._base_ctx.current_utc_datetime + + # Activities & Sub-orchestrations + def call_activity( + self, + activity_fn: Callable[..., Any], + *, + input: Any = None, + retry_policy: Any = None, + metadata: dict[str, str] | None = None, + ) -> Awaitable[Any]: + return ActivityAwaitable( + self._base_ctx, activity_fn, input=input, retry_policy=retry_policy, metadata=metadata + ) + + def call_child_workflow( + self, + workflow_fn: Callable[..., Any], + *, + input: Any = None, + instance_id: str | None = None, + retry_policy: Any = None, + metadata: dict[str, str] | None = None, + ) -> Awaitable[Any]: + return SubOrchestratorAwaitable( + self._base_ctx, + workflow_fn, + input=input, + instance_id=instance_id, + retry_policy=retry_policy, + metadata=metadata, + ) + + @property + def is_replaying(self) -> bool: + return self._base_ctx.is_replaying + + # Timers & Events + def create_timer(self, fire_at: float | timedelta | datetime) -> Awaitable[None]: + # If float provided, interpret as seconds + if isinstance(fire_at, (int, float)): + fire_at = timedelta(seconds=float(fire_at)) + return SleepAwaitable(self._base_ctx, fire_at) + + def sleep(self, duration: float | timedelta | datetime) -> Awaitable[None]: + return self.create_timer(duration) + + def wait_for_external_event(self, name: str) -> Awaitable[Any]: + return ExternalEventAwaitable(self._base_ctx, name) + + # Concurrency + def when_all(self, awaitables: Sequence[Awaitable[Any]]) -> Awaitable[list[Any]]: + return WhenAllAwaitable(awaitables) + + def when_any(self, awaitables: Sequence[Awaitable[Any]]) -> Awaitable[Any]: + return WhenAnyAwaitable(awaitables) + + def gather(self, *aws: Awaitable[Any], return_exceptions: bool = False) -> Awaitable[list[Any]]: + return _dt_gather(*aws, return_exceptions=return_exceptions) + + # Deterministic utilities are provided by mixin (now, random, uuid4, new_guid) + + @property + def is_suspended(self) -> bool: + # Placeholder; will be wired when Durable Task exposes this state in context + return self._base_ctx.is_suspended + + # Pass-throughs for completeness + def set_custom_status(self, custom_status: str) -> None: + if hasattr(self._base_ctx, 'set_custom_status'): + self._base_ctx.set_custom_status(custom_status) + + def continue_as_new( + self, + new_input: Any, + *, + save_events: bool = False, + carryover_metadata: bool | dict[str, str] = False, + carryover_headers: bool | dict[str, str] | None = None, + ) -> None: + effective_carryover = ( + carryover_headers if carryover_headers is not None else carryover_metadata + ) + # Try extended signature; fall back to minimal for older fakes/contexts + try: + self._base_ctx.continue_as_new( + new_input, save_events=save_events, carryover_metadata=effective_carryover + ) + except TypeError: + self._base_ctx.continue_as_new(new_input, save_events=save_events) + + # Metadata parity + def set_metadata(self, metadata: dict[str, str] | None) -> None: + setter = getattr(self._base_ctx, 'set_metadata', None) + if callable(setter): + setter(metadata) + + def get_metadata(self) -> dict[str, str] | None: + getter = getattr(self._base_ctx, 'get_metadata', None) + return getter() if callable(getter) else None + + # 
Header aliases (ergonomic alias for users familiar with Temporal terminology) + def set_headers(self, headers: dict[str, str] | None) -> None: + self.set_metadata(headers) + + def get_headers(self) -> dict[str, str] | None: + return self.get_metadata() + + # Execution info parity + @property + def execution_info(self): # type: ignore[override] + return getattr(self._base_ctx, 'execution_info', None) + + +__all__ = [ + 'AsyncWorkflowContext', +] diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/aio/async_driver.py b/ext/dapr-ext-workflow/dapr/ext/workflow/aio/async_driver.py new file mode 100644 index 00000000..7d964174 --- /dev/null +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/aio/async_driver.py @@ -0,0 +1,95 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +from typing import Any, Awaitable, Callable, Generator, Optional + +from durabletask import task +from durabletask.aio.sandbox import SandboxMode, sandbox_scope + + +class CoroutineOrchestratorRunner: + """Wraps an async orchestrator into a generator-compatible runner.""" + + def __init__( + self, + async_orchestrator: Callable[..., Awaitable[Any]], + *, + sandbox_mode: SandboxMode = SandboxMode.OFF, + ): + self._async_orchestrator = async_orchestrator + self._sandbox_mode = sandbox_mode + + def to_generator( + self, async_ctx: Any, input_data: Optional[Any] + ) -> Generator[task.Task, Any, Any]: + # Instantiate the coroutine with or without input depending on signature/usage + try: + if input_data is None: + coro = self._async_orchestrator(async_ctx) + else: + coro = self._async_orchestrator(async_ctx, input_data) + except TypeError: + # Fallback for orchestrators that only accept a single ctx arg + coro = self._async_orchestrator(async_ctx) + + # Prime the coroutine + try: + if self._sandbox_mode == SandboxMode.OFF: + awaited = coro.send(None) + else: + with sandbox_scope(async_ctx, self._sandbox_mode): + awaited = coro.send(None) + except StopIteration as stop: + return stop.value # type: ignore[misc] + + # Drive the coroutine by yielding the underlying Durable Task(s) + while True: + try: + result = yield awaited + if self._sandbox_mode == SandboxMode.OFF: + awaited = coro.send(result) + else: + with sandbox_scope(async_ctx, self._sandbox_mode): + awaited = coro.send(result) + except StopIteration as stop: + return stop.value + except Exception as exc: + try: + if self._sandbox_mode == SandboxMode.OFF: + awaited = coro.throw(exc) + else: + with sandbox_scope(async_ctx, self._sandbox_mode): + awaited = coro.throw(exc) + except StopIteration as stop: + return stop.value + except BaseException as base_exc: + # Handle cancellation that may not derive from Exception in some environments + try: + import asyncio as _asyncio # local import to avoid hard dep at module import time + + is_cancel = isinstance(base_exc, _asyncio.CancelledError) + except Exception: + is_cancel = False + if is_cancel: + try: + if self._sandbox_mode == SandboxMode.OFF: + awaited = coro.throw(base_exc) + 
else: + with sandbox_scope(async_ctx, self._sandbox_mode): + awaited = coro.throw(base_exc) + except StopIteration as stop: + return stop.value + continue + raise diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/aio/awaitables.py b/ext/dapr-ext-workflow/dapr/ext/workflow/aio/awaitables.py new file mode 100644 index 00000000..494887de --- /dev/null +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/aio/awaitables.py @@ -0,0 +1,120 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +from collections.abc import Iterable +from typing import Any, Callable + +from durabletask import task +from durabletask.aio.awaitables import ( + AwaitableBase as _BaseAwaitable, # type: ignore[import-not-found] +) +from durabletask.aio.awaitables import ( + ExternalEventAwaitable as _DTExternalEventAwaitable, +) +from durabletask.aio.awaitables import ( + SleepAwaitable as _DTSleepAwaitable, +) +from durabletask.aio.awaitables import ( + WhenAllAwaitable as _DTWhenAllAwaitable, +) + +AwaitableBase = _BaseAwaitable + + +class ActivityAwaitable(AwaitableBase): + def __init__( + self, + ctx: Any, + activity_fn: Callable[..., Any], + *, + input: Any = None, + retry_policy: Any = None, + metadata: dict[str, str] | None = None, + app_id: str | None = None, + ): + self._ctx = ctx + self._activity_fn = activity_fn + self._input = input + self._retry_policy = retry_policy + self._metadata = metadata + self._app_id = app_id + + def _to_task(self) -> task.Task: + return self._ctx.call_activity( + self._activity_fn, + input=self._input, + retry_policy=self._retry_policy, + metadata=self._metadata, + app_id=self._app_id, + ) + + +class SubOrchestratorAwaitable(AwaitableBase): + def __init__( + self, + ctx: Any, + workflow_fn: Callable[..., Any], + *, + input: Any = None, + instance_id: str | None = None, + retry_policy: Any = None, + metadata: dict[str, str] | None = None, + app_id: str | None = None, + ): + self._ctx = ctx + self._workflow_fn = workflow_fn + self._input = input + self._instance_id = instance_id + self._retry_policy = retry_policy + self._metadata = metadata + self._app_id = app_id + + def _to_task(self) -> task.Task: + return self._ctx.call_child_workflow( + self._workflow_fn, + input=self._input, + instance_id=self._instance_id, + retry_policy=self._retry_policy, + metadata=self._metadata, + app_id=self._app_id, + ) + + +class SleepAwaitable(_DTSleepAwaitable): + pass + + +class ExternalEventAwaitable(_DTExternalEventAwaitable): + pass + + +class WhenAllAwaitable(_DTWhenAllAwaitable): + pass + + +class WhenAnyAwaitable(AwaitableBase): + def __init__(self, tasks_like: Iterable[AwaitableBase | task.Task]): + self._tasks_like = list(tasks_like) + + def _to_task(self) -> task.Task: + underlying: list[task.Task] = [] + for a in self._tasks_like: + if isinstance(a, AwaitableBase): + underlying.append(a._to_task()) # type: ignore[attr-defined] + elif isinstance(a, task.Task): + underlying.append(a) + else: + raise TypeError('when_any expects AwaitableBase or 
durabletask.task.Task') + return task.when_any(underlying) diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/aio/sandbox.py b/ext/dapr-ext-workflow/dapr/ext/workflow/aio/sandbox.py new file mode 100644 index 00000000..3d887e17 --- /dev/null +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/aio/sandbox.py @@ -0,0 +1,233 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +import asyncio as _asyncio +import random as _random +import time as _time +import uuid as _uuid +from contextlib import ContextDecorator +from typing import Any + +from durabletask.aio.sandbox import SandboxMode +from durabletask.deterministic import deterministic_random, deterministic_uuid4 + +""" +Scoped sandbox patching for async workflows (best-effort, strict). +""" + + +def _ctx_instance_id(async_ctx: Any) -> str: + if hasattr(async_ctx, 'instance_id'): + return getattr(async_ctx, 'instance_id') + if hasattr(async_ctx, '_base_ctx') and hasattr(async_ctx._base_ctx, 'instance_id'): + return async_ctx._base_ctx.instance_id + return '' + + +def _ctx_now(async_ctx: Any): + if hasattr(async_ctx, 'now'): + try: + return async_ctx.now() + except Exception: + pass + if hasattr(async_ctx, 'current_utc_datetime'): + return async_ctx.current_utc_datetime + if hasattr(async_ctx, '_base_ctx') and hasattr(async_ctx._base_ctx, 'current_utc_datetime'): + return async_ctx._base_ctx.current_utc_datetime + import datetime as _dt + + return _dt.datetime.utcfromtimestamp(0) + + +class _Sandbox(ContextDecorator): + def __init__(self, async_ctx: Any, mode: str): + self._async_ctx = async_ctx + self._mode = mode + self._saved: dict[str, Any] = {} + + def __enter__(self): + self._saved['asyncio.sleep'] = _asyncio.sleep + self._saved['asyncio.gather'] = getattr(_asyncio, 'gather', None) + self._saved['asyncio.create_task'] = getattr(_asyncio, 'create_task', None) + self._saved['random.random'] = _random.random + self._saved['random.randrange'] = _random.randrange + self._saved['random.randint'] = _random.randint + self._saved['uuid.uuid4'] = _uuid.uuid4 + self._saved['time.time'] = _time.time + self._saved['time.time_ns'] = getattr(_time, 'time_ns', None) + + rnd = deterministic_random(_ctx_instance_id(self._async_ctx), _ctx_now(self._async_ctx)) + + async def _sleep_patched(delay: float, result: Any = None): # type: ignore[override] + try: + if float(delay) <= 0: + return await self._saved['asyncio.sleep'](0) + except Exception: + return await self._saved['asyncio.sleep'](delay) # type: ignore[arg-type] + + await self._async_ctx.sleep(delay) + return result + + def _random_patched() -> float: + return rnd.random() + + def _randrange_patched(start, stop=None, step=1): + return rnd.randrange(start, stop, step) if stop is not None else rnd.randrange(start) + + def _randint_patched(a, b): + return rnd.randint(a, b) + + def _uuid4_patched(): + return deterministic_uuid4(rnd) + + def _time_patched() -> float: + return float(_ctx_now(self._async_ctx).timestamp()) + + def _time_ns_patched() -> int: + 
return int(_ctx_now(self._async_ctx).timestamp() * 1_000_000_000) + + def _create_task_blocked(coro, *args, **kwargs): + try: + close = getattr(coro, 'close', None) + if callable(close): + try: + close() + except Exception: + pass + finally: + raise RuntimeError( + 'asyncio.create_task is not allowed inside workflow (strict mode)' + ) + + def _is_workflow_awaitable(obj: Any) -> bool: + try: + if hasattr(obj, '_to_dapr_task') or hasattr(obj, '_to_task'): + return True + except Exception: + pass + try: + from durabletask import task as _dt + + if isinstance(obj, _dt.Task): + return True + except Exception: + pass + return False + + class _OneShot: + def __init__(self, factory): + self._factory = factory + self._done = False + self._res: Any = None + self._exc: BaseException | None = None + + def __await__(self): # type: ignore[override] + if self._done: + + async def _replay(): + if self._exc is not None: + raise self._exc + return self._res + + return _replay().__await__() + + async def _compute(): + try: + out = await self._factory() + self._res = out + self._done = True + return out + except BaseException as e: # noqa: BLE001 + self._exc = e + self._done = True + raise + + return _compute().__await__() + + def _patched_gather(*aws: Any, return_exceptions: bool = False): # type: ignore[override] + if not aws: + + async def _empty(): + return [] + + return _OneShot(_empty) + + if all(_is_workflow_awaitable(a) for a in aws): + + async def _await_when_all(): + from dapr.ext.workflow.aio.awaitables import WhenAllAwaitable # local import + + combined = WhenAllAwaitable(list(aws)) + return await combined + + return _OneShot(_await_when_all) + + async def _run_mixed(): + results = [] + for a in aws: + try: + results.append(await a) + except Exception as e: # noqa: BLE001 + if return_exceptions: + results.append(e) + else: + raise + return results + + return _OneShot(_run_mixed) + + _asyncio.sleep = _sleep_patched # type: ignore[assignment] + if self._saved['asyncio.gather'] is not None: + _asyncio.gather = _patched_gather # type: ignore[assignment] + _random.random = _random_patched # type: ignore[assignment] + _random.randrange = _randrange_patched # type: ignore[assignment] + _random.randint = _randint_patched # type: ignore[assignment] + _uuid.uuid4 = _uuid4_patched # type: ignore[assignment] + _time.time = _time_patched # type: ignore[assignment] + if self._saved['time.time_ns'] is not None: + _time.time_ns = _time_ns_patched # type: ignore[assignment] + if self._mode == 'strict' and self._saved['asyncio.create_task'] is not None: + _asyncio.create_task = _create_task_blocked # type: ignore[assignment] + + return self + + def __exit__(self, exc_type, exc, tb): + _asyncio.sleep = self._saved['asyncio.sleep'] # type: ignore[assignment] + if self._saved['asyncio.gather'] is not None: + _asyncio.gather = self._saved['asyncio.gather'] # type: ignore[assignment] + if self._saved['asyncio.create_task'] is not None: + _asyncio.create_task = self._saved['asyncio.create_task'] # type: ignore[assignment] + _random.random = self._saved['random.random'] # type: ignore[assignment] + _random.randrange = self._saved['random.randrange'] # type: ignore[assignment] + _random.randint = self._saved['random.randint'] # type: ignore[assignment] + _uuid.uuid4 = self._saved['uuid.uuid4'] # type: ignore[assignment] + _time.time = self._saved['time.time'] # type: ignore[assignment] + if self._saved['time.time_ns'] is not None: + _time.time_ns = self._saved['time.time_ns'] # type: ignore[assignment] + return False + + 
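To make the effect of the patches above concrete, a sketch of an orchestrator written against plain stdlib calls; it assumes the async driver runs it under a non-`OFF` sandbox mode so these patches are active:

```python
import asyncio
import random
import time
import uuid


async def sandboxed_workflow(ctx, _input=None):
    # With the sandbox active, these stdlib calls are rewritten as shown above:
    await asyncio.sleep(5)      # routed to ctx.sleep(5) -> durable timer
    request_id = uuid.uuid4()   # deterministic UUID derived from instance id + orchestration time
    jitter = random.random()    # deterministic PRNG seeded the same way
    started = time.time()       # orchestration time, not wall-clock time
    return {'request_id': str(request_id), 'jitter': jitter, 'started': started}
```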
+def sandbox_scope(async_ctx: Any, mode: SandboxMode): + if mode == SandboxMode.OFF: + + class _Null(ContextDecorator): + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + return _Null() + return _Sandbox(async_ctx, 'strict' if mode == SandboxMode.STRICT else 'best_effort') diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_client.py b/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_client.py index 461bfd43..0faa64a8 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_client.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_client.py @@ -19,6 +19,12 @@ from typing import Any, Optional, TypeVar import durabletask.internal.orchestrator_service_pb2 as pb +from dapr.ext.workflow.interceptors import ( + ClientInterceptor, + ScheduleWorkflowRequest, + compose_client_chain, + wrap_payload_with_metadata, +) from dapr.ext.workflow.logger import Logger, LoggerOptions from dapr.ext.workflow.util import getAddress from dapr.ext.workflow.workflow_context import Workflow @@ -51,6 +57,8 @@ def __init__( host: Optional[str] = None, port: Optional[str] = None, logger_options: Optional[LoggerOptions] = None, + *, + interceptors: list[ClientInterceptor] | None = None, ): address = getAddress(host, port) @@ -61,18 +69,31 @@ def __init__( self._logger = Logger('DaprWorkflowClient', logger_options) - metadata = tuple() + metadata = () if settings.DAPR_API_TOKEN: metadata = ((DAPR_API_TOKEN_HEADER, settings.DAPR_API_TOKEN),) options = self._logger.get_options() + # Optional gRPC channel options (keepalive, retry policy) via helpers + # channel_options = build_grpc_channel_options() + + # Construct base kwargs for TaskHubGrpcClient + base_kwargs = { + 'host_address': uri.endpoint, + 'metadata': metadata, + 'secure_channel': uri.tls, + 'log_handler': options.log_handler, + 'log_formatter': options.log_formatter, + } + + # Initialize TaskHubGrpcClient (DurableTask supports options) self.__obj = client.TaskHubGrpcClient( - host_address=uri.endpoint, - metadata=metadata, - secure_channel=uri.tls, - log_handler=options.log_handler, - log_formatter=options.log_formatter, + **base_kwargs, + # channel_options=channel_options, ) + # Interceptors + self._client_interceptors: list[ClientInterceptor] = list(interceptors or []) + def schedule_new_workflow( self, workflow: Workflow, @@ -81,6 +102,7 @@ def schedule_new_workflow( instance_id: Optional[str] = None, start_at: Optional[datetime] = None, reuse_id_policy: Optional[pb.OrchestrationIdReusePolicy] = None, + metadata: dict[str, str] | None = None, ) -> str: """Schedules a new workflow instance for execution. @@ -95,25 +117,39 @@ def schedule_new_workflow( be scheduled immediately. reuse_id_policy: Optional policy to reuse the workflow id when there is a conflict with an existing workflow instance. + metadata (dict[str, str] | None): Optional dictionary of key-value pairs + to be included as metadata/headers for the workflow. Returns: The ID of the scheduled workflow instance. 
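A short sketch of the extended client surface, reusing `ContextClientInterceptor` from the interceptors example; `order_workflow` is a hypothetical workflow registered with the runtime:

```python
from dapr.ext.workflow import DaprWorkflowClient

# Client interceptors run before the schedule call; the metadata dict is wrapped
# into the input envelope and restored by runtime interceptors on the worker side.
client = DaprWorkflowClient(interceptors=[ContextClientInterceptor()])
instance_id = client.schedule_new_workflow(
    workflow=order_workflow,  # hypothetical registered workflow
    input={'order_id': 42},
    metadata={'tenant': 'acme', 'request_id': 'r-123'},
)
print('scheduled', instance_id)
```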
""" - if hasattr(workflow, '_dapr_alternate_name'): + wf_name = ( + workflow.__dict__['_dapr_alternate_name'] + if hasattr(workflow, '_dapr_alternate_name') + else workflow.__name__ + ) + + # Build interceptor chain around schedule call + def terminal(term_req: ScheduleWorkflowRequest) -> str: + payload = wrap_payload_with_metadata(term_req.input, term_req.metadata) return self.__obj.schedule_new_orchestration( - workflow.__dict__['_dapr_alternate_name'], - input=input, - instance_id=instance_id, - start_at=start_at, - reuse_id_policy=reuse_id_policy, + term_req.workflow_name, + input=payload, + instance_id=term_req.instance_id, + start_at=term_req.start_at, + reuse_id_policy=term_req.reuse_id_policy, ) - return self.__obj.schedule_new_orchestration( - workflow.__name__, + + chain = compose_client_chain(self._client_interceptors, terminal) + schedule_req = ScheduleWorkflowRequest( + workflow_name=wf_name, input=input, instance_id=instance_id, start_at=start_at, reuse_id_policy=reuse_id_policy, + metadata=metadata, ) + return chain(schedule_req) def get_workflow_state( self, instance_id: str, *, fetch_payloads: bool = True diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_context.py b/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_context.py index 714def3f..9701769a 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_context.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_context.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """ Copyright 2023 The Dapr Authors Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,28 +11,61 @@ limitations under the License. """ +import enum from datetime import datetime, timedelta from typing import Any, Callable, List, Optional, TypeVar, Union +from dapr.ext.workflow.execution_info import WorkflowExecutionInfo +from dapr.ext.workflow.interceptors import unwrap_payload_with_metadata, wrap_payload_with_metadata from dapr.ext.workflow.logger import Logger, LoggerOptions from dapr.ext.workflow.retry_policy import RetryPolicy from dapr.ext.workflow.workflow_activity_context import WorkflowActivityContext from dapr.ext.workflow.workflow_context import Workflow, WorkflowContext from durabletask import task +from durabletask.deterministic import ( # type: ignore[F401] + DeterministicContextMixin, +) T = TypeVar('T') TInput = TypeVar('TInput') TOutput = TypeVar('TOutput') -class DaprWorkflowContext(WorkflowContext): - """DaprWorkflowContext that provides proxy access to internal OrchestrationContext instance.""" +class Handlers(enum.Enum): + CALL_ACTIVITY = 'call_activity' + CALL_CHILD_WORKFLOW = 'call_child_workflow' + CONTINUE_AS_NEW = 'continue_as_new' + + +class DaprWorkflowContext(WorkflowContext, DeterministicContextMixin): + """Workflow context wrapper with deterministic utilities and metadata helpers. + + Purpose + ------- + - Proxy to the underlying ``durabletask.task.OrchestrationContext`` (engine fields like + ``trace_parent``, ``orchestration_span_id``, and ``workflow_attempt`` pass through). + - Provide SDK-level helpers for durable metadata propagation via interceptors. + - Expose ``execution_info`` as a per-activation snapshot complementing live properties. + + Tips + ---- + - Use ``ctx.get_metadata()/set_metadata()`` to manage outbound propagation. + - Use ``ctx.execution_info.inbound_metadata`` to inspect what arrived on this activation. 
+ - Prefer engine-backed properties for tracing/attempts when available (not yet available in dapr sidecar); fall back to + metadata only for app-specific context. + """ def __init__( - self, ctx: task.OrchestrationContext, logger_options: Optional[LoggerOptions] = None + self, + ctx: task.OrchestrationContext, + logger_options: Optional[LoggerOptions] = None, + *, + outbound_handlers: dict[Handlers, Any] | None = None, ): self.__obj = ctx self._logger = Logger('DaprWorkflowContext', logger_options) + self._outbound_handlers = outbound_handlers or {} + self._metadata: dict[str, str] | None = None # provide proxy access to regular attributes of wrapped object def __getattr__(self, name): @@ -52,10 +83,34 @@ def current_utc_datetime(self) -> datetime: def is_replaying(self) -> bool: return self.__obj.is_replaying + # Deterministic utilities are provided by mixin (now, random, uuid4, new_guid) + + # Metadata API + def set_metadata(self, metadata: dict[str, str] | None) -> None: + self._metadata = dict(metadata) if metadata else None + + def get_metadata(self) -> dict[str, str] | None: + return dict(self._metadata) if self._metadata else None + + # Header aliases (ergonomic alias for users familiar with Temporal terminology) + def set_headers(self, headers: dict[str, str] | None) -> None: + self.set_metadata(headers) + + def get_headers(self) -> dict[str, str] | None: + return self.get_metadata() + def set_custom_status(self, custom_status: str) -> None: self._logger.debug(f'{self.instance_id}: Setting custom status to {custom_status}') self.__obj.set_custom_status(custom_status) + # Execution info (populated by runtime when available) + @property + def execution_info(self) -> WorkflowExecutionInfo | None: + return getattr(self, '_execution_info', None) + + def _set_execution_info(self, info: WorkflowExecutionInfo) -> None: + self._execution_info = info + def create_timer(self, fire_at: Union[datetime, timedelta]) -> task.Task: self._logger.debug(f'{self.instance_id}: Creating timer to fire at {fire_at} time') return self.__obj.create_timer(fire_at) @@ -67,6 +122,7 @@ def call_activity( input: TInput = None, retry_policy: Optional[RetryPolicy] = None, app_id: Optional[str] = None, + metadata: dict[str, str] | None = None, ) -> task.Task[TOutput]: # Handle string activity names for cross-app scenarios if isinstance(activity, str): @@ -91,10 +147,18 @@ def call_activity( else: # this case should ideally never happen act = activity.__name__ + # Apply outbound client interceptor transformations if provided via runtime wiring + transformed_input: Any = input + if Handlers.CALL_ACTIVITY in self._outbound_handlers and callable( + self._outbound_handlers[Handlers.CALL_ACTIVITY] + ): + transformed_input = self._outbound_handlers[Handlers.CALL_ACTIVITY]( + self, activity, input, retry_policy, metadata or self.get_metadata() + ) if retry_policy is None: - return self.__obj.call_activity(activity=act, input=input, app_id=app_id) + return self.__obj.call_activity(activity=act, input=transformed_input, app_id=app_id) return self.__obj.call_activity( - activity=act, input=input, retry_policy=retry_policy.obj, app_id=app_id + activity=act, input=transformed_input, retry_policy=retry_policy.obj, app_id=app_id ) def call_child_workflow( @@ -105,6 +169,7 @@ def call_child_workflow( instance_id: Optional[str] = None, retry_policy: Optional[RetryPolicy] = None, app_id: Optional[str] = None, + metadata: dict[str, str] | None = None, ) -> task.Task[TOutput]: # Handle string workflow names for cross-app scenarios if 
isinstance(workflow, str): @@ -127,8 +192,8 @@ def call_child_workflow( self._logger.debug(f'{self.instance_id}: Creating child workflow {workflow.__name__}') def wf(ctx: task.OrchestrationContext, inp: TInput): - daprWfContext = DaprWorkflowContext(ctx, self._logger.get_options()) - return workflow(daprWfContext, inp) + dapr_wf_context = DaprWorkflowContext(ctx, self._logger.get_options()) + return workflow(dapr_wf_context, inp) # copy workflow name so durabletask.worker can find the orchestrator in its registry @@ -137,24 +202,79 @@ def wf(ctx: task.OrchestrationContext, inp: TInput): else: # this case should ideally never happen wf.__name__ = workflow.__name__ + # Apply outbound client interceptor transformations if provided via runtime wiring + transformed_input: Any = input + if Handlers.CALL_CHILD_WORKFLOW in self._outbound_handlers and callable( + self._outbound_handlers[Handlers.CALL_CHILD_WORKFLOW] + ): + transformed_input = self._outbound_handlers[Handlers.CALL_CHILD_WORKFLOW]( + self, workflow, input, metadata or self.get_metadata() + ) if retry_policy is None: return self.__obj.call_sub_orchestrator( - wf, input=input, instance_id=instance_id, app_id=app_id + wf, input=transformed_input, instance_id=instance_id, app_id=app_id ) return self.__obj.call_sub_orchestrator( - wf, input=input, instance_id=instance_id, retry_policy=retry_policy.obj, app_id=app_id + wf, + input=transformed_input, + instance_id=instance_id, + retry_policy=retry_policy.obj, + app_id=app_id, ) def wait_for_external_event(self, name: str) -> task.Task: self._logger.debug(f'{self.instance_id}: Waiting for external event {name}') return self.__obj.wait_for_external_event(name) - def continue_as_new(self, new_input: Any, *, save_events: bool = False) -> None: + def continue_as_new( + self, + new_input: Any, + *, + save_events: bool = False, + carryover_metadata: bool = False, + metadata: dict[str, str] | None = None, + ) -> None: + """ + Continue the workflow execution with new inputs and optional metadata or headers. + + This method allows restarting the workflow execution with new input parameters, + while optionally preserving workflow events, metadata, and/or headers. It also + integrates with workflow interceptors if configured, enabling custom modification + of inputs and associated metadata before continuation. + + Args: + new_input: Any new input to pass to the workflow upon continuation. + save_events (bool): Indicates whether to preserve the event history of the + workflow execution. Defaults to False. + carryover_metadata bool: If True, carries + over metadata from the current execution. + metadata dict[str, str] | None: If a dictionary is provided, it + will be added to the current metadata. If carryover_metadata is True, + the contents of the dictionary will be merged with the current metadata. 
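A hedged sketch of how a workflow might combine these options; the polling workflow, timer interval, and metadata keys are invented for illustration:

```python
from datetime import timedelta

from dapr.ext.workflow import DaprWorkflowContext, WorkflowRuntime

wfr = WorkflowRuntime()

@wfr.workflow(name='poller')
def poller(ctx: DaprWorkflowContext, count: int):
    # Sleep, then restart this instance with a new input.
    yield ctx.create_timer(timedelta(seconds=30))
    ctx.continue_as_new(
        count + 1,
        carryover_metadata=True,                 # keep the current context metadata
        metadata={'iteration': str(count + 1)},  # merged on top (string-only values)
    )
```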
+ """ self._logger.debug(f'{self.instance_id}: Continuing as new') - self.__obj.continue_as_new(new_input, save_events=save_events) + # Allow workflow outbound interceptors (wired via runtime) to modify payload/metadata + transformed_input: Any = new_input + if Handlers.CONTINUE_AS_NEW in self._outbound_handlers and callable( + self._outbound_handlers[Handlers.CONTINUE_AS_NEW] + ): + transformed_input = self._outbound_handlers[Handlers.CONTINUE_AS_NEW]( + self, new_input, self.get_metadata() + ) + + # Merge/carry metadata if requested, unwrapping any envelope produced by interceptors + payload, base_md = unwrap_payload_with_metadata(transformed_input) + # Start with current context metadata; then layer any interceptor-provided metadata on top + current_md = self.get_metadata() or {} + effective_md = (current_md | (base_md or {})) if carryover_metadata else {} + if metadata is not None: + effective_md = effective_md | metadata + + payload = wrap_payload_with_metadata(payload, effective_md) + self.__obj.continue_as_new(payload, save_events=save_events) -def when_all(tasks: List[task.Task[T]]) -> task.WhenAllTask[T]: +def when_all(tasks: List[task.Task]) -> task.WhenAllTask: """Returns a task that completes when all of the provided tasks complete or when one of the tasks fail.""" return task.when_all(tasks) diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/deterministic.py b/ext/dapr-ext-workflow/dapr/ext/workflow/deterministic.py new file mode 100644 index 00000000..d33a02c6 --- /dev/null +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/deterministic.py @@ -0,0 +1,27 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +# Backward-compatible shim: import deterministic utilities from durabletask +from durabletask.deterministic import ( # type: ignore[F401] + DeterministicContextMixin, + deterministic_random, + deterministic_uuid4, +) + +__all__ = [ + 'DeterministicContextMixin', + 'deterministic_random', + 'deterministic_uuid4', +] diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/execution_info.py b/ext/dapr-ext-workflow/dapr/ext/workflow/execution_info.py new file mode 100644 index 00000000..0aacd710 --- /dev/null +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/execution_info.py @@ -0,0 +1,49 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +from dataclasses import dataclass + +""" +Minimal, deterministic snapshots of inbound durable metadata. 
+ +Rationale +--------- + +Execution info previously mirrored many engine fields (IDs, tracing, attempts) already +available on the workflow/activity contexts. To remove redundancy and simplify usage, the +execution info types now only capture the durable ``inbound_metadata`` that was actually +propagated into this activation. Use context properties directly for engine fields. +""" + + +@dataclass +class WorkflowExecutionInfo: + """Per-activation snapshot for workflows. + + Only includes ``inbound_metadata`` that arrived with this activation. + """ + + inbound_metadata: dict[str, str] + + +@dataclass +class ActivityExecutionInfo: + """Per-activation snapshot for activities. + + Only includes ``inbound_metadata`` that arrived with this activity invocation. + """ + + inbound_metadata: dict[str, str] + activity_name: str diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/interceptors.py b/ext/dapr-ext-workflow/dapr/ext/workflow/interceptors.py new file mode 100644 index 00000000..90ebb7ea --- /dev/null +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/interceptors.py @@ -0,0 +1,414 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable, Generic, Protocol, TypeVar + +from dapr.ext.workflow.workflow_activity_context import WorkflowActivityContext +from dapr.ext.workflow.workflow_context import WorkflowContext + +# Type variables for generic interceptor payload typing +TInput = TypeVar('TInput') +TWorkflowInput = TypeVar('TWorkflowInput') +TActivityInput = TypeVar('TActivityInput') + +""" +Interceptor interfaces and chain utilities for the Dapr Workflow SDK. + +Providing a single enter/exit around calls. + +IMPORTANT: Generator wrappers for async workflows +-------------------------------------------------- +When writing runtime interceptors that touch workflow execution, be careful with generator +handling. If an interceptor obtains a workflow generator from user code (e.g., an async +orchestrator adapted into a generator) it must not manually iterate it using a for-loop +and yield the produced items. Doing so breaks send()/throw() propagation back into the +inner generator, which can cause resumed results from the durable runtime to be dropped +and appear as None to awaiters. + +Best practices: +- If the interceptor participates in composition and needs to return the generator, + return it directly (do not iterate it). +- If the interceptor must wrap the generator, always use "yield from inner_gen" so that + send()/throw() are forwarded correctly. + +Context managers with async workflows +-------------------------------------- +When using context managers (like ExitStack, logging contexts, or trace contexts) in an +interceptor for async workflows, be aware that calling `next(input)` returns a generator +object immediately, NOT the final result. The generator executes later when the durable +task runtime drives it. 
+ +If you need a context manager to remain active during the workflow execution: + +**WRONG - Context exits before workflow runs:** + + def execute_workflow(self, input: ExecuteWorkflowRequest, next): + with setup_context(): + return next(input) # Returns generator, context exits immediately! + +**CORRECT - Context stays active throughout execution:** + + def execute_workflow(self, input: ExecuteWorkflowRequest, next): + def wrapper(): + with setup_context(): + gen = next(input) + yield from gen # Keep context alive while generator executes + return wrapper() + +For more complex scenarios with ExitStack or async context managers, wrap the generator +with `yield from` to ensure your context spans the entire workflow execution, including +all replay and continuation events. + +Example with ExitStack: + + def execute_workflow(self, input: ExecuteWorkflowRequest, next): + def wrapper(): + with ExitStack() as stack: + # Set up contexts (trace, logging, etc.) + stack.enter_context(trace_context(...)) + stack.enter_context(logging_context(...)) + + # Get the generator from the next interceptor/handler + gen = next(input) + + # Keep contexts alive while generator executes + yield from gen + return wrapper() + +This pattern ensures your context manager remains active during: +- Initial workflow execution +- Replays from durable state +- Continuation after awaits +- Activity calls and child workflow invocations +""" + + +# Context metadata propagation +# ---------------------------- +# "metadata" is a durable, string-only map. It is serialized on the wire and propagates across +# boundaries (client → runtime → activity/child), surviving replays/retries. Use it when downstream +# components must observe the value. In-process ephemeral state should be handled within interceptors +# without attempting to propagate across process boundaries. + + +# ------------------------------ +# Client-side interceptor surface +# ------------------------------ + + +@dataclass +class ScheduleWorkflowRequest(Generic[TInput]): + workflow_name: str + input: TInput + instance_id: str | None + start_at: Any | None + reuse_id_policy: Any | None + # Durable context serialized and propagated across boundaries + metadata: dict[str, str] | None = None + + +@dataclass +class CallChildWorkflowRequest(Generic[TInput]): + workflow_name: str + input: TInput + instance_id: str | None + # Optional workflow context for outbound calls made inside workflows + workflow_ctx: Any | None = None + # Durable context serialized and propagated across boundaries + metadata: dict[str, str] | None = None + + +@dataclass +class ContinueAsNewRequest(Generic[TInput]): + input: TInput + # Optional workflow context for outbound calls made inside workflows + workflow_ctx: Any | None = None + # Durable context serialized and propagated across boundaries + metadata: dict[str, str] | None = None + + +@dataclass +class CallActivityRequest(Generic[TInput]): + activity_name: str + input: TInput + retry_policy: Any | None + # Optional workflow context for outbound calls made inside workflows + workflow_ctx: Any | None = None + # Durable context serialized and propagated across boundaries + metadata: dict[str, str] | None = None + + +class ClientInterceptor(Protocol, Generic[TInput]): + def schedule_new_workflow( + self, + input: ScheduleWorkflowRequest[TInput], + next: Callable[[ScheduleWorkflowRequest[TInput]], Any], + ) -> Any: ... 
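As a quick illustration of this client surface, here is a hedged sketch of an interceptor that stamps a correlation id into the durable metadata before scheduling; the class name and header key are invented:

```python
from dapr.ext.workflow import DaprWorkflowClient
from dapr.ext.workflow.interceptors import BaseClientInterceptor, ScheduleWorkflowRequest

class CorrelationClientInterceptor(BaseClientInterceptor):
    """Adds a correlation id to every scheduled workflow (illustrative)."""

    def schedule_new_workflow(self, input: ScheduleWorkflowRequest, next):
        md = dict(input.metadata or {})
        md.setdefault('x-correlation-id', 'req-123')  # hypothetical key/value
        input.metadata = md
        return next(input)  # hand off to the rest of the chain / terminal scheduler

client = DaprWorkflowClient(interceptors=[CorrelationClientInterceptor()])
```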
+ + +# ------------------------------- +# Runtime-side interceptor surface +# ------------------------------- + + +@dataclass +class ExecuteWorkflowRequest(Generic[TInput]): + ctx: WorkflowContext + input: TInput + # Durable metadata (runtime chain only; not injected into user code) + metadata: dict[str, str] | None = None + + +@dataclass +class ExecuteActivityRequest(Generic[TInput]): + ctx: WorkflowActivityContext + input: TInput + # Durable metadata (runtime chain only; not injected into user code) + metadata: dict[str, str] | None = None + + +class RuntimeInterceptor(Protocol, Generic[TWorkflowInput, TActivityInput]): + def execute_workflow( + self, + input: ExecuteWorkflowRequest[TWorkflowInput], + next: Callable[[ExecuteWorkflowRequest[TWorkflowInput]], Any], + ) -> Any: ... + + def execute_activity( + self, + input: ExecuteActivityRequest[TActivityInput], + next: Callable[[ExecuteActivityRequest[TActivityInput]], Any], + ) -> Any: ... + + +# ------------------------------ +# Convenience base classes (devex) +# ------------------------------ + + +class BaseClientInterceptor(Generic[TInput]): + """Subclass this to get method name completion and safe defaults. + + Override any of the methods to customize behavior. By default, these + methods simply call `next` unchanged. + """ + + def schedule_new_workflow( + self, + input: ScheduleWorkflowRequest[TInput], + next: Callable[[ScheduleWorkflowRequest[TInput]], Any], + ) -> Any: # noqa: D401 + return next(input) + + # No workflow-outbound methods here; use WorkflowOutboundInterceptor for those + + +class BaseRuntimeInterceptor(Generic[TWorkflowInput, TActivityInput]): + """Subclass this to get method name completion and safe defaults.""" + + def execute_workflow( + self, + input: ExecuteWorkflowRequest[TWorkflowInput], + next: Callable[[ExecuteWorkflowRequest[TWorkflowInput]], Any], + ) -> Any: # noqa: D401 + return next(input) + + def execute_activity( + self, + input: ExecuteActivityRequest[TActivityInput], + next: Callable[[ExecuteActivityRequest[TActivityInput]], Any], + ) -> Any: # noqa: D401 + return next(input) + + +# ------------------------------ +# Helper: chain composition +# ------------------------------ + + +def compose_client_chain( + interceptors: list[ClientInterceptor], terminal: Callable[[Any], Any] +) -> Callable[[Any], Any]: + """Compose client interceptors into a single callable. + + Interceptors are applied in list order; each receives a ``next``. + The ``terminal`` callable is the final handler invoked after all interceptors; it + performs the base operation (e.g., scheduling the workflow) when the chain ends. + """ + next_fn = terminal + for icpt in reversed(interceptors or []): + + def make_next(curr_icpt: ClientInterceptor, nxt: Callable[[Any], Any]): + def runner(input: Any) -> Any: + if isinstance(input, ScheduleWorkflowRequest): + return curr_icpt.schedule_new_workflow(input, nxt) + return nxt(input) + + return runner + + next_fn = make_next(icpt, next_fn) + return next_fn + + +# ------------------------------ +# Workflow outbound interceptor surface +# ------------------------------ + + +class WorkflowOutboundInterceptor(Protocol, Generic[TWorkflowInput, TActivityInput]): + def call_child_workflow( + self, + input: CallChildWorkflowRequest[TWorkflowInput], + next: Callable[[CallChildWorkflowRequest[TWorkflowInput]], Any], + ) -> Any: ... + + def continue_as_new( + self, + input: ContinueAsNewRequest[TWorkflowInput], + next: Callable[[ContinueAsNewRequest[TWorkflowInput]], Any], + ) -> Any: ... 
+ + def call_activity( + self, + input: CallActivityRequest[TActivityInput], + next: Callable[[CallActivityRequest[TActivityInput]], Any], + ) -> Any: ... + + +class BaseWorkflowOutboundInterceptor(Generic[TWorkflowInput, TActivityInput]): + def call_child_workflow( + self, + input: CallChildWorkflowRequest[TWorkflowInput], + next: Callable[[CallChildWorkflowRequest[TWorkflowInput]], Any], + ) -> Any: + return next(input) + + def continue_as_new( + self, + input: ContinueAsNewRequest[TWorkflowInput], + next: Callable[[ContinueAsNewRequest[TWorkflowInput]], Any], + ) -> Any: + return next(input) + + def call_activity( + self, + input: CallActivityRequest[TActivityInput], + next: Callable[[CallActivityRequest[TActivityInput]], Any], + ) -> Any: + return next(input) + + +# ------------------------------ +# Backward-compat typing aliases +# ------------------------------ + + +def compose_workflow_outbound_chain( + interceptors: list[WorkflowOutboundInterceptor], + terminal: Callable[[Any], Any], +) -> Callable[[Any], Any]: + """Compose workflow outbound interceptors into a single callable. + + Interceptors are applied in list order; each receives a ``next``. + The ``terminal`` callable is the final handler invoked after all interceptors; it + performs the base operation (e.g., preparing outbound call args) when the chain ends. + """ + next_fn = terminal + for icpt in reversed(interceptors or []): + + def make_next(curr_icpt: WorkflowOutboundInterceptor, nxt: Callable[[Any], Any]): + def runner(input: Any) -> Any: + # Dispatch to the appropriate outbound method on the interceptor + if isinstance(input, CallActivityRequest): + return curr_icpt.call_activity(input, nxt) + if isinstance(input, CallChildWorkflowRequest): + return curr_icpt.call_child_workflow(input, nxt) + if isinstance(input, ContinueAsNewRequest): + return curr_icpt.continue_as_new(input, nxt) + # Fallback to next if input type unknown + return nxt(input) + + return runner + + next_fn = make_next(icpt, next_fn) + return next_fn + + +# ------------------------------ +# Helper: envelope for durable metadata +# ------------------------------ + +_META_KEY = '__dapr_meta__' +_META_VERSION = 1 +_PAYLOAD_KEY = '__dapr_payload__' + + +def wrap_payload_with_metadata(payload: Any, metadata: dict[str, str] | None) -> Any: + """If metadata is provided and non-empty, wrap payload in an envelope for persistence. + + Backward compatible: if metadata is falsy, return payload unchanged. + """ + if metadata: + return { + _META_KEY: { + 'v': _META_VERSION, + 'metadata': metadata, + }, + _PAYLOAD_KEY: payload, + } + return payload + + +def unwrap_payload_with_metadata(obj: Any) -> tuple[Any, dict[str, str] | None]: + """Extract payload and metadata from envelope if present. + + Returns (payload, metadata_dict_or_none). + """ + try: + if isinstance(obj, dict) and _META_KEY in obj and _PAYLOAD_KEY in obj: + meta = obj.get(_META_KEY) or {} + md = meta.get('metadata') if isinstance(meta, dict) else None + return obj.get(_PAYLOAD_KEY), md if isinstance(md, dict) else None + except Exception: + # Be robust: on any error, treat as raw payload + pass + return obj, None + + +def compose_runtime_chain( + interceptors: list[RuntimeInterceptor], final_handler: Callable[[Any], Any] +): + """Compose runtime interceptors into a single callable (synchronous). 
+ + The ``final_handler`` callable is the final handler invoked after all interceptors; it + performs the core operation (e.g., calling user workflow/activity or returning a + workflow generator) when the chain ends. + """ + next_fn = final_handler + for icpt in reversed(interceptors or []): + + def make_next(curr_icpt: RuntimeInterceptor, nxt: Callable[[Any], Any]): + def runner(input: Any) -> Any: + if isinstance(input, ExecuteWorkflowRequest): + return curr_icpt.execute_workflow(input, nxt) + if isinstance(input, ExecuteActivityRequest): + return curr_icpt.execute_activity(input, nxt) + return nxt(input) + + return runner + + next_fn = make_next(icpt, next_fn) + return next_fn diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/serializers.py b/ext/dapr-ext-workflow/dapr/ext/workflow/serializers.py new file mode 100644 index 00000000..09af5188 --- /dev/null +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/serializers.py @@ -0,0 +1,173 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +import json +from collections.abc import MutableMapping, MutableSequence +from typing import ( + Any, + Callable, + Dict, + Optional, + Protocol, + cast, +) + +""" +General-purpose, provider-agnostic JSON serialization helpers for workflow activities. + +This module focuses on generic extension points to ensure activity inputs/outputs are JSON-only +and replay-safe. It intentionally avoids provider-specific shapes (e.g., model/tool contracts), +which should live in examples or external packages. +""" + + +def _is_json_primitive(value: Any) -> bool: + return value is None or isinstance(value, (str, int, float, bool)) + + +def _to_json_safe(value: Any, *, strict: bool) -> Any: + """Convert a Python object to a JSON-serializable structure. + + - Dict keys become strings (lenient) or error (strict) if not str. + - Unsupported values become str(value) (lenient) or error (strict). 
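To make the lenient/strict distinction concrete, a small sketch using the public `ensure_canonical_json` helper defined later in this module (the `Point` class is invented):

```python
from dapr.ext.workflow.serializers import ensure_canonical_json

class Point:  # not JSON-serializable on its own
    def __init__(self, x: int, y: int):
        self.x, self.y = x, y

# Lenient mode stringifies non-string keys and coerces unsupported values via str().
ensure_canonical_json({1: Point(1, 2)}, strict=False)  # -> {'1': '<...Point object at 0x...>'}

# Strict mode raises ValueError instead of silently coercing.
try:
    ensure_canonical_json({1: Point(1, 2)}, strict=True)
except ValueError as err:
    print(err)  # dict keys must be strings in strict mode
```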
+ """ + + if _is_json_primitive(value): + return value + + if isinstance(value, MutableSequence) or isinstance(value, tuple): + return [_to_json_safe(v, strict=strict) for v in value] + + if isinstance(value, MutableMapping) or isinstance(value, dict): + output: Dict[str, Any] = {} + for k, v in value.items(): + if not isinstance(k, str): + if strict: + raise ValueError('dict keys must be strings in strict mode') + k = str(k) + output[k] = _to_json_safe(v, strict=strict) + return output + + if strict: + # Attempt final json.dumps to surface type + try: + json.dumps(value) + except Exception as err: + raise ValueError(f'non-JSON-serializable value: {type(value).__name__}') from err + return value + + return str(value) + + +def _ensure_json(obj: Any, *, strict: bool) -> Any: + converted = _to_json_safe(obj, strict=strict) + # json.dumps as a final guard + json.dumps(converted) + return converted + + +# ---------------------------------------------------------------------------------------------- +# Generic helpers and extension points +# ---------------------------------------------------------------------------------------------- + + +class CanonicalSerializable(Protocol): + """Objects implementing this can produce a canonical JSON-serializable structure.""" + + def to_canonical_json(self, *, strict: bool = True) -> Any: ... + + +class GenericSerializer(Protocol): + """Serializer that converts arbitrary Python objects to/from JSON-serializable data.""" + + def serialize(self, obj: Any, *, strict: bool = True) -> Any: ... + + def deserialize(self, data: Any) -> Any: ... + + +_SERIALIZERS: Dict[str, GenericSerializer] = {} + + +def register_serializer(name: str, serializer: GenericSerializer) -> None: + if not name: + raise ValueError('serializer name must be non-empty') + _SERIALIZERS[name] = serializer + + +def get_serializer(name: str) -> Optional[GenericSerializer]: + return _SERIALIZERS.get(name) + + +def ensure_canonical_json(obj: Any, *, strict: bool = True) -> Any: + """Ensure any object is converted into a JSON-serializable structure. + + - If the object implements CanonicalSerializable, call to_canonical_json + - Else, coerce via the internal JSON-safe conversion + """ + + if hasattr(obj, 'to_canonical_json') and callable(getattr(obj, 'to_canonical_json')): + return _ensure_json( + cast(CanonicalSerializable, obj).to_canonical_json(strict=strict), strict=strict + ) + return _ensure_json(obj, strict=strict) + + +class ActivityIOAdapter(Protocol): + """Adapter to control how activity inputs/outputs are serialized.""" + + def serialize_input(self, input: Any, *, strict: bool = True) -> Any: ... + + def serialize_output(self, output: Any, *, strict: bool = True) -> Any: ... 
+ + +_ACTIVITY_ADAPTERS: Dict[str, ActivityIOAdapter] = {} + + +def register_activity_adapter(name: str, adapter: ActivityIOAdapter) -> None: + if not name: + raise ValueError('activity adapter name must be non-empty') + _ACTIVITY_ADAPTERS[name] = adapter + + +def get_activity_adapter(name: str) -> Optional[ActivityIOAdapter]: + return _ACTIVITY_ADAPTERS.get(name) + + +def use_activity_adapter( + adapter: ActivityIOAdapter, +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + """Decorator to attach an ActivityIOAdapter to an activity function.""" + + def _decorate(f: Callable[..., Any]) -> Callable[..., Any]: + cast(Any, f).__dapr_activity_io_adapter__ = adapter + return f + + return _decorate + + +def serialize_activity_input(func: Callable[..., Any], input: Any, *, strict: bool = True) -> Any: + adapter = getattr(func, '__dapr_activity_io_adapter__', None) + if adapter: + return cast(ActivityIOAdapter, adapter).serialize_input(input, strict=strict) + return ensure_canonical_json(input, strict=strict) + + +def serialize_activity_output(func: Callable[..., Any], output: Any, *, strict: bool = True) -> Any: + adapter = getattr(func, '__dapr_activity_io_adapter__', None) + if adapter: + return cast(ActivityIOAdapter, adapter).serialize_output(output, strict=strict) + return ensure_canonical_json(output, strict=strict) diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_activity_context.py b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_activity_context.py index 331ad6c2..79c53ab8 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_activity_context.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_activity_context.py @@ -17,6 +17,7 @@ from typing import Callable, TypeVar +from dapr.ext.workflow.execution_info import ActivityExecutionInfo from durabletask import task T = TypeVar('T') @@ -25,10 +26,20 @@ class WorkflowActivityContext: - """Defines properties and methods for task activity context objects.""" + """Wrapper for ``durabletask.task.ActivityContext`` with metadata helpers. + + Purpose + ------- + - Provide pass-throughs for engine fields (``trace_parent``, ``trace_state``, + and parent ``workflow_span_id`` when available). + - Surface ``execution_info``: a per-activation snapshot that includes the + ``inbound_metadata`` actually received for this activity. + - Offer ``get_metadata()/set_metadata()`` for SDK-level durable metadata management. 
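To make the activity-side surface concrete, a small sketch of reading the propagated metadata and execution info inside an activity; the activity name and metadata key are illustrative:

```python
from dapr.ext.workflow import WorkflowActivityContext, WorkflowRuntime

wfr = WorkflowRuntime()

@wfr.activity(name='charge_card')
def charge_card(ctx: WorkflowActivityContext, amount: float) -> str:
    # Durable metadata propagated from the caller (if any).
    tenant = (ctx.get_metadata() or {}).get('tenant-id', 'unknown')
    info = ctx.execution_info  # per-activation snapshot set by the runtime
    name = info.activity_name if info else 'charge_card'
    return f'{name}: charged {amount} for tenant {tenant}'
```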
+ """ def __init__(self, ctx: task.ActivityContext): self.__obj = ctx + self._metadata: dict[str, str] | None = None @property def workflow_id(self) -> str: @@ -43,6 +54,20 @@ def task_id(self) -> int: def get_inner_context(self) -> task.ActivityContext: return self.__obj + @property + def execution_info(self) -> ActivityExecutionInfo | None: + return getattr(self, '_execution_info', None) + + def _set_execution_info(self, info: ActivityExecutionInfo) -> None: + self._execution_info = info + + # Metadata accessors (SDK-level; set by runtime inbound if available) + def set_metadata(self, metadata: dict[str, str] | None) -> None: + self._metadata = dict(metadata) if metadata else None + + def get_metadata(self) -> dict[str, str] | None: + return dict(self._metadata) if self._metadata else None + # Activities are simple functions that can be scheduled by workflows Activity = Callable[..., TOutput] diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_context.py b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_context.py index 8453e16e..e8c1e640 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_context.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_context.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """ Copyright 2023 The Dapr Authors Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,8 +15,9 @@ from abc import ABC, abstractmethod from datetime import datetime, timedelta -from typing import Any, Callable, Generator, Optional, TypeVar, Union +from typing import Any, Callable, Generator, TypeVar, Union +from dapr.ext.workflow.retry_policy import RetryPolicy from dapr.ext.workflow.workflow_activity_context import Activity from durabletask import task @@ -90,7 +89,7 @@ def set_custom_status(self, custom_status: str) -> None: pass @abstractmethod - def create_timer(self, fire_at: Union[datetime, timedelta]) -> task.Task: + def create_timer(self, fire_at: datetime | timedelta) -> task.Task: """Create a Timer Task to fire after at the specified deadline. Parameters @@ -110,8 +109,9 @@ def call_activity( self, activity: Union[Activity[TOutput], str], *, - input: Optional[TInput] = None, - app_id: Optional[str] = None, + input: TInput | None = None, + app_id: str | None = None, + retry_policy: RetryPolicy | None = None, ) -> task.Task[TOutput]: """Schedule an activity for execution. @@ -123,6 +123,7 @@ def call_activity( The JSON-serializable input (or None) to pass to the activity. app_id: str | None The AppID that will execute the activity. + retry_policy: RetryPolicy | None Returns ------- @@ -136,9 +137,10 @@ def call_child_workflow( self, orchestrator: Union[Workflow[TOutput], str], *, - input: Optional[TInput] = None, - instance_id: Optional[str] = None, - app_id: Optional[str] = None, + input: TInput | None = None, + instance_id: str | None = None, + app_id: str | None = None, + retry_policy: RetryPolicy | None = None, ) -> task.Task[TOutput]: """Schedule child-workflow function for execution. @@ -153,6 +155,9 @@ def call_child_workflow( random UUID will be used. app_id: str The AppID that will execute the workflow. + retry_policy: RetryPolicy | None + Optional retry policy for the child-workflow. When provided, failures will be retried + according to the policy. 
Returns ------- diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py index 593e55c6..a7098f63 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py @@ -13,17 +13,34 @@ limitations under the License. """ +import asyncio import inspect +import traceback from functools import wraps -from typing import Optional, Sequence, TypeVar, Union +from typing import Any, Awaitable, Callable, List, Optional, Sequence, TypeVar, Union import grpc -from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext +from dapr.ext.workflow.aio import AsyncWorkflowContext, CoroutineOrchestratorRunner +from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext, Handlers +from dapr.ext.workflow.execution_info import ActivityExecutionInfo, WorkflowExecutionInfo +from dapr.ext.workflow.interceptors import ( + CallActivityRequest, + CallChildWorkflowRequest, + ExecuteActivityRequest, + ExecuteWorkflowRequest, + RuntimeInterceptor, + WorkflowOutboundInterceptor, + compose_runtime_chain, + compose_workflow_outbound_chain, + unwrap_payload_with_metadata, + wrap_payload_with_metadata, +) from dapr.ext.workflow.logger import Logger, LoggerOptions from dapr.ext.workflow.util import getAddress from dapr.ext.workflow.workflow_activity_context import Activity, WorkflowActivityContext from dapr.ext.workflow.workflow_context import Workflow from durabletask import task, worker +from durabletask.aio.sandbox import SandboxMode from dapr.clients import DaprInternalError from dapr.clients.http.client import DAPR_API_TOKEN_HEADER @@ -47,16 +64,19 @@ class WorkflowRuntime: def __init__( self, - host: Optional[str] = None, - port: Optional[str] = None, + host: str | None = None, + port: str | None = None, logger_options: Optional[LoggerOptions] = None, interceptors: Optional[Sequence[ClientInterceptor]] = None, maximum_concurrent_activity_work_items: Optional[int] = None, maximum_concurrent_orchestration_work_items: Optional[int] = None, maximum_thread_pool_workers: Optional[int] = None, + *, + runtime_interceptors: Optional[list[RuntimeInterceptor]] = None, + workflow_outbound_interceptors: Optional[list[WorkflowOutboundInterceptor]] = None, ): self._logger = Logger('WorkflowRuntime', logger_options) - metadata = tuple() + metadata = () if settings.DAPR_API_TOKEN: metadata = ((DAPR_API_TOKEN_HEADER, settings.DAPR_API_TOKEN),) address = getAddress(host, port) @@ -80,16 +100,128 @@ def __init__( maximum_thread_pool_workers=maximum_thread_pool_workers, ), ) + # Interceptors + self._runtime_interceptors: List[RuntimeInterceptor] = list(runtime_interceptors or []) + self._workflow_outbound_interceptors: List[WorkflowOutboundInterceptor] = list( + workflow_outbound_interceptors or [] + ) + + # Outbound helpers apply interceptors and wrap metadata; no built-in transformations. 
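For orientation, a hedged sketch of wiring these interceptor lists into the runtime; the interceptor classes and metadata key are invented, and a fuller worked example appears in `examples/generics_interceptors_example.py` further below:

```python
from dapr.ext.workflow import WorkflowRuntime
from dapr.ext.workflow.interceptors import (
    BaseRuntimeInterceptor,
    BaseWorkflowOutboundInterceptor,
)

class LoggingRuntimeInterceptor(BaseRuntimeInterceptor):
    def execute_workflow(self, input, next):
        # Return next() directly; for generator workflows this is the generator
        # itself and must not be manually iterated here.
        print(f'workflow starting, inbound metadata={input.metadata}')
        return next(input)

class TenantOutboundInterceptor(BaseWorkflowOutboundInterceptor):
    def call_activity(self, input, next):
        input.metadata = {**(input.metadata or {}), 'tenant-id': 'acme'}  # hypothetical key
        return next(input)

runtime = WorkflowRuntime(
    runtime_interceptors=[LoggingRuntimeInterceptor()],
    workflow_outbound_interceptors=[TenantOutboundInterceptor()],
)
```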
+ def _apply_outbound_activity( + self, + ctx: Any, + activity: Callable[..., Any] | str, + input: Any, + retry_policy: Any | None, + metadata: dict[str, str] | None = None, + ): + # Build workflow-outbound chain to transform CallActivityRequest + name = ( + activity + if isinstance(activity, str) + else ( + activity.__dict__['_dapr_alternate_name'] + if hasattr(activity, '_dapr_alternate_name') + else activity.__name__ + ) + ) + + def terminal(term_req: CallActivityRequest) -> CallActivityRequest: + return term_req + + chain = compose_workflow_outbound_chain(self._workflow_outbound_interceptors, terminal) + # Use per-context default metadata when not provided + metadata = metadata or ctx.get_metadata() + act_req = CallActivityRequest( + activity_name=name, + input=input, + retry_policy=retry_policy, + workflow_ctx=ctx, + metadata=metadata, + ) + out = chain(act_req) + if isinstance(out, CallActivityRequest): + return wrap_payload_with_metadata(out.input, out.metadata) + return input + + def _apply_outbound_child( + self, + ctx: Any, + workflow: Callable[..., Any] | str, + input: Any, + metadata: dict[str, str] | None = None, + ): + name = ( + workflow + if isinstance(workflow, str) + else ( + workflow.__dict__['_dapr_alternate_name'] + if hasattr(workflow, '_dapr_alternate_name') + else workflow.__name__ + ) + ) + + def terminal(term_req: CallChildWorkflowRequest) -> CallChildWorkflowRequest: + return term_req + + chain = compose_workflow_outbound_chain(self._workflow_outbound_interceptors, terminal) + metadata = metadata or ctx.get_metadata() + child_req = CallChildWorkflowRequest( + workflow_name=name, input=input, instance_id=None, workflow_ctx=ctx, metadata=metadata + ) + out = chain(child_req) + if isinstance(out, CallChildWorkflowRequest): + return wrap_payload_with_metadata(out.input, out.metadata) + return input + + def _apply_outbound_continue_as_new( + self, + ctx: Any, + new_input: Any, + metadata: dict[str, str] | None = None, + ): + # Build workflow-outbound chain to transform ContinueAsNewRequest + from dapr.ext.workflow.interceptors import ContinueAsNewRequest + + def terminal(term_req: ContinueAsNewRequest) -> ContinueAsNewRequest: + return term_req + + chain = compose_workflow_outbound_chain(self._workflow_outbound_interceptors, terminal) + metadata = metadata or ctx.get_metadata() + cnr = ContinueAsNewRequest(input=new_input, workflow_ctx=ctx, metadata=metadata) + out = chain(cnr) + if isinstance(out, ContinueAsNewRequest): + return wrap_payload_with_metadata(out.input, out.metadata) + return new_input def register_workflow(self, fn: Workflow, *, name: Optional[str] = None): + # Seamlessly support async workflows using the existing API + if inspect.iscoroutinefunction(fn): + return self.register_async_workflow(fn, name=name) + self._logger.info(f"Registering workflow '{fn.__name__}' with runtime") - def orchestrationWrapper(ctx: task.OrchestrationContext, inp: Optional[TInput] = None): - """Responsible to call Workflow function in orchestrationWrapper""" - daprWfContext = DaprWorkflowContext(ctx, self._logger.get_options()) - if inp is None: - return fn(daprWfContext) - return fn(daprWfContext, inp) + def orchestration_wrapper(ctx: task.OrchestrationContext, inp: Optional[TInput] = None): + """Orchestration entrypoint wrapped by runtime interceptors.""" + payload, md = unwrap_payload_with_metadata(inp) + dapr_wf_context = self._get_workflow_context(ctx, md) + + # Build interceptor chain; terminal calls the user function (generator or non-generator) + def 
final_handler(exec_req: ExecuteWorkflowRequest) -> Any: + try: + return ( + fn(dapr_wf_context) + if exec_req.input is None + else fn(dapr_wf_context, exec_req.input) + ) + except Exception as exc: # log and re-raise to surface failure details + self._logger.error( + f"{ctx.instance_id}: workflow '{fn.__name__}' raised {type(exc).__name__}: {exc}\n{traceback.format_exc()}" + ) + raise + + chain = compose_runtime_chain(self._runtime_interceptors, final_handler) + return chain(ExecuteWorkflowRequest(ctx=dapr_wf_context, input=payload, metadata=md)) if hasattr(fn, '_workflow_registered'): # whenever a workflow is registered, it has a _dapr_alternate_name attribute @@ -104,7 +236,7 @@ def orchestrationWrapper(ctx: task.OrchestrationContext, inp: Optional[TInput] = fn.__dict__['_dapr_alternate_name'] = name if name else fn.__name__ self.__worker._registry.add_named_orchestrator( - fn.__dict__['_dapr_alternate_name'], orchestrationWrapper + fn.__dict__['_dapr_alternate_name'], orchestration_wrapper ) fn.__dict__['_workflow_registered'] = True @@ -114,12 +246,44 @@ def register_activity(self, fn: Activity, *, name: Optional[str] = None): """ self._logger.info(f"Registering activity '{fn.__name__}' with runtime") - def activityWrapper(ctx: task.ActivityContext, inp: Optional[TInput] = None): - """Responsible to call Activity function in activityWrapper""" - wfActivityContext = WorkflowActivityContext(ctx) - if inp is None: - return fn(wfActivityContext) - return fn(wfActivityContext, inp) + def activity_wrapper(ctx: task.ActivityContext, inp: Optional[TInput] = None): + """Activity entrypoint wrapped by runtime interceptors.""" + wf_activity_context = WorkflowActivityContext(ctx) + payload, md = unwrap_payload_with_metadata(inp) + # Populate inbound metadata onto activity context + wf_activity_context.set_metadata(md or {}) + + # Populate execution info + try: + # Determine activity name (registered alternate name or function __name__) + act_name = getattr(fn, '_dapr_alternate_name', fn.__name__) + ainfo = ActivityExecutionInfo(inbound_metadata=md or {}, activity_name=act_name) + wf_activity_context._set_execution_info(ainfo) + except Exception: + pass + + def final_handler(exec_req: ExecuteActivityRequest) -> Any: + try: + # Support async and sync activities + if inspect.iscoroutinefunction(fn): + if exec_req.input is None: + return asyncio.run(fn(wf_activity_context)) + return asyncio.run(fn(wf_activity_context, exec_req.input)) + if exec_req.input is None: + return fn(wf_activity_context) + return fn(wf_activity_context, exec_req.input) + except Exception as exc: + # Log details for troubleshooting (metadata, error type) + self._logger.error( + f"{ctx.orchestration_id}:{ctx.task_id} activity '{fn.__name__}' failed with {type(exc).__name__}: {exc}" + ) + self._logger.error(traceback.format_exc()) + raise + + chain = compose_runtime_chain(self._runtime_interceptors, final_handler) + return chain( + ExecuteActivityRequest(ctx=wf_activity_context, input=payload, metadata=md) + ) if hasattr(fn, '_activity_registered'): # whenever an activity is registered, it has a _dapr_alternate_name attribute @@ -134,7 +298,7 @@ def activityWrapper(ctx: task.ActivityContext, inp: Optional[TInput] = None): fn.__dict__['_dapr_alternate_name'] = name if name else fn.__name__ self.__worker._registry.add_named_activity( - fn.__dict__['_dapr_alternate_name'], activityWrapper + fn.__dict__['_dapr_alternate_name'], activity_wrapper ) fn.__dict__['_activity_registered'] = True @@ -144,7 +308,13 @@ def start(self): def 
shutdown(self): """Stops the listening for work items on a background thread.""" - self.__worker.stop() + try: + self._logger.info('Stopping gRPC worker...') + self.__worker.stop() + self._logger.info('Worker shutdown completed') + except Exception as exc: # pragma: no cover + # DurableTask worker may emit CANCELLED warnings during local shutdown; not fatal + self._logger.warning(f'Worker stop encountered {type(exc).__name__}: {exc}') def workflow(self, __fn: Workflow = None, *, name: Optional[str] = None): """Decorator to register a workflow function. @@ -174,7 +344,11 @@ def add(ctx, x: int, y: int) -> int: """ def wrapper(fn: Workflow): - self.register_workflow(fn, name=name) + # Auto-detect coroutine and delegate to async registration + if inspect.iscoroutinefunction(fn): + self.register_async_workflow(fn, name=name) + else: + self.register_workflow(fn, name=name) @wraps(fn) def innerfn(): @@ -194,6 +368,121 @@ def innerfn(): return wrapper + # Async orchestrator registration (additive) + def register_async_workflow( + self, + fn: Callable[[AsyncWorkflowContext, Any], Awaitable[Any]], + *, + name: Optional[str] = None, + sandbox_mode: SandboxMode = SandboxMode.OFF, + ) -> None: + """Register an async workflow function. + + The async workflow is wrapped by a coroutine-to-generator driver so it can be + executed by the Durable Task runtime alongside existing generator workflows. + + Args: + fn: The async workflow function, taking ``AsyncWorkflowContext`` and optional input. + name: Optional alternate name for registration. + sandbox_mode: Scoped compatibility patching mode. + """ + self._logger.info(f"Registering ASYNC workflow '{fn.__name__}' with runtime") + + if hasattr(fn, '_workflow_registered'): + alt_name = fn.__dict__['_dapr_alternate_name'] + raise ValueError(f'Workflow {fn.__name__} already registered as {alt_name}') + if hasattr(fn, '_dapr_alternate_name'): + alt_name = fn._dapr_alternate_name + if name is not None: + m = f'Workflow {fn.__name__} already has an alternate name {alt_name}' + raise ValueError(m) + else: + fn.__dict__['_dapr_alternate_name'] = name if name else fn.__name__ + + runner = CoroutineOrchestratorRunner(fn, sandbox_mode=sandbox_mode) + + def generator_orchestrator(ctx: task.OrchestrationContext, inp: Optional[Any] = None): + """Orchestration entrypoint wrapped by runtime interceptors.""" + payload, md = unwrap_payload_with_metadata(inp) + base_ctx = self._get_workflow_context(ctx, md) + + async_ctx = AsyncWorkflowContext(base_ctx) + + def final_handler(exec_req: ExecuteWorkflowRequest) -> Any: + # Build the generator using the (potentially shaped) input from interceptors. + shaped_input = exec_req.input + return runner.to_generator(async_ctx, shaped_input) + + chain = compose_runtime_chain(self._runtime_interceptors, final_handler) + return chain(ExecuteWorkflowRequest(ctx=async_ctx, input=payload, metadata=md)) + + self.__worker._registry.add_named_orchestrator( + fn.__dict__['_dapr_alternate_name'], generator_orchestrator + ) + fn.__dict__['_workflow_registered'] = True + + def _get_workflow_context( + self, ctx: task.OrchestrationContext, metadata: dict[str, str] | None = None + ) -> DaprWorkflowContext: + """Get the workflow context and execution info for the given orchestration context and metadata. + Execution info serves as a read-only snapshot of the workflow context. + + Args: + ctx: The orchestration context. + metadata: The metadata for the workflow. + + Returns: + The workflow context. 
+ """ + base_ctx = DaprWorkflowContext( + ctx, + self._logger.get_options(), + outbound_handlers={ + Handlers.CALL_ACTIVITY: self._apply_outbound_activity, + Handlers.CALL_CHILD_WORKFLOW: self._apply_outbound_child, + Handlers.CONTINUE_AS_NEW: self._apply_outbound_continue_as_new, + }, + ) + # Populate minimal execution info (only inbound metadata) + info = WorkflowExecutionInfo(inbound_metadata=metadata or {}) + base_ctx._set_execution_info(info) + base_ctx.set_metadata(metadata or {}) + return base_ctx + + def async_workflow( + self, + __fn: Callable[[AsyncWorkflowContext, Any], Awaitable[Any]] = None, + *, + name: Optional[str] = None, + sandbox_mode: SandboxMode = SandboxMode.OFF, + ): + """Decorator to register an async workflow function. + + Usage: + @runtime.async_workflow(name="my_wf") + async def my_wf(ctx: AsyncWorkflowContext, input): + ... + """ + + def wrapper(fn: Callable[[AsyncWorkflowContext, Any], Awaitable[Any]]): + self.register_async_workflow(fn, name=name, sandbox_mode=sandbox_mode) + + @wraps(fn) + def innerfn(): + return fn + + if hasattr(fn, '_dapr_alternate_name'): + innerfn.__dict__['_dapr_alternate_name'] = fn.__dict__['_dapr_alternate_name'] + else: + innerfn.__dict__['_dapr_alternate_name'] = name if name else fn.__name__ + innerfn.__signature__ = inspect.signature(fn) + return innerfn + + if __fn: + return wrapper(__fn) + + return wrapper + def activity(self, __fn: Activity = None, *, name: Optional[str] = None): """Decorator to register an activity function. diff --git a/ext/dapr-ext-workflow/examples/generics_interceptors_example.py b/ext/dapr-ext-workflow/examples/generics_interceptors_example.py new file mode 100644 index 00000000..7ef9c8f4 --- /dev/null +++ b/ext/dapr-ext-workflow/examples/generics_interceptors_example.py @@ -0,0 +1,197 @@ +from __future__ import annotations + +import os +from dataclasses import asdict, dataclass +from typing import List + +from dapr.ext.workflow import ( + DaprWorkflowClient, + WorkflowRuntime, +) +from dapr.ext.workflow.interceptors import ( + BaseClientInterceptor, + BaseRuntimeInterceptor, + BaseWorkflowOutboundInterceptor, + CallActivityRequest, + CallChildWorkflowRequest, + ContinueAsNewRequest, + ExecuteActivityRequest, + ExecuteWorkflowRequest, + ScheduleWorkflowRequest, +) + +# ------------------------------ +# Typed payloads carried by interceptors +# ------------------------------ + + +@dataclass +class MyWorkflowInput: + question: str + tags: List[str] + + +@dataclass +class MyActivityInput: + name: str + count: int + + +# ------------------------------ +# Interceptors with generics + minimal (de)serialization +# ------------------------------ + + +class MyClientInterceptor(BaseClientInterceptor[MyWorkflowInput]): + def schedule_new_workflow( + self, + input: ScheduleWorkflowRequest[MyWorkflowInput], + nxt, + ) -> str: + # Ensure wire format is JSON-serializable (dict) + payload = ( + asdict(input.input) if hasattr(input.input, '__dataclass_fields__') else input.input + ) + shaped = ScheduleWorkflowRequest[MyWorkflowInput]( + workflow_name=input.workflow_name, + input=payload, # type: ignore[arg-type] + instance_id=input.instance_id, + start_at=input.start_at, + reuse_id_policy=input.reuse_id_policy, + metadata=input.metadata, + ) + return nxt(shaped) + + +class MyRuntimeInterceptor(BaseRuntimeInterceptor[MyWorkflowInput, MyActivityInput]): + def execute_workflow( + self, + input: ExecuteWorkflowRequest[MyWorkflowInput], + nxt, + ): + # Convert inbound dict into typed model for workflow code + data = 
input.input + if isinstance(data, dict) and 'question' in data: + input.input = MyWorkflowInput( + question=data.get('question', ''), tags=list(data.get('tags', [])) + ) # type: ignore[assignment] + return nxt(input) + + def execute_activity( + self, + input: ExecuteActivityRequest[MyActivityInput], + nxt, + ): + data = input.input + if isinstance(data, dict) and 'name' in data: + input.input = MyActivityInput( + name=data.get('name', ''), count=int(data.get('count', 0)) + ) # type: ignore[assignment] + return nxt(input) + + +class MyOutboundInterceptor(BaseWorkflowOutboundInterceptor[MyWorkflowInput, MyActivityInput]): + def call_child_workflow( + self, + input: CallChildWorkflowRequest[MyWorkflowInput], + nxt, + ): + # Convert typed payload back to wire before sending + payload = ( + asdict(input.input) if hasattr(input.input, '__dataclass_fields__') else input.input + ) + shaped = CallChildWorkflowRequest[MyWorkflowInput]( + workflow_name=input.workflow_name, + input=payload, # type: ignore[arg-type] + instance_id=input.instance_id, + workflow_ctx=input.workflow_ctx, + metadata=input.metadata, + ) + return nxt(shaped) + + def continue_as_new( + self, + input: ContinueAsNewRequest[MyWorkflowInput], + nxt, + ): + payload = ( + asdict(input.input) if hasattr(input.input, '__dataclass_fields__') else input.input + ) + shaped = ContinueAsNewRequest[MyWorkflowInput]( + input=payload, # type: ignore[arg-type] + workflow_ctx=input.workflow_ctx, + metadata=input.metadata, + ) + return nxt(shaped) + + def call_activity( + self, + input: CallActivityRequest[MyActivityInput], + nxt, + ): + payload = ( + asdict(input.input) if hasattr(input.input, '__dataclass_fields__') else input.input + ) + shaped = CallActivityRequest[MyActivityInput]( + activity_name=input.activity_name, + input=payload, # type: ignore[arg-type] + retry_policy=input.retry_policy, + workflow_ctx=input.workflow_ctx, + metadata=input.metadata, + ) + return nxt(shaped) + + +# ------------------------------ +# Minimal runnable example with sidecar +# ------------------------------ + + +def main() -> None: + # Expect DAPR_GRPC_ENDPOINT (e.g., dns:127.0.0.1:56179) to be set for local sidecar/dev hub + ep = os.getenv('DAPR_GRPC_ENDPOINT') + if not ep: + print('WARNING: DAPR_GRPC_ENDPOINT not set; default sidecar address will be used') + + # Build runtime with interceptors + runtime = WorkflowRuntime( + runtime_interceptors=[MyRuntimeInterceptor()], + workflow_outbound_interceptors=[MyOutboundInterceptor()], + ) + + # Register a simple activity + @runtime.activity(name='greet') + def greet(_ctx, x: dict | None = None) -> str: # wire format at activity boundary is dict + x = x or {} + return f'Hello {x.get("name", "world")} x{x.get("count", 0)}' + + # Register an async workflow that calls the activity once + @runtime.async_workflow(name='wf_greet') + async def wf_greet(ctx, arg: MyWorkflowInput | dict | None = None): + # At this point, runtime interceptor converted inbound to MyWorkflowInput + if isinstance(arg, MyWorkflowInput): + act_in = MyActivityInput(name=arg.question, count=len(arg.tags)) + else: + # Fallback if interceptor not present + d = arg or {} + act_in = MyActivityInput(name=str(d.get('question', '')), count=len(d.get('tags', []))) + return await ctx.call_activity('greet', input=asdict(act_in)) + + runtime.start() + try: + # Client with client-side interceptor for schedule typing + client = DaprWorkflowClient(interceptors=[MyClientInterceptor()]) + wf_input = MyWorkflowInput(question='World', tags=['a', 'b']) + instance_id 
= client.schedule_new_workflow(wf_greet, input=wf_input) + print('Started instance:', instance_id) + client.wait_for_workflow_start(instance_id, timeout_in_seconds=30) + st = client.wait_for_workflow_completion(instance_id, timeout_in_seconds=60) + print('Final status:', getattr(st, 'runtime_status', None)) + if st: + print('Output:', st.to_json().get('serialized_output')) + finally: + runtime.shutdown() + + +if __name__ == '__main__': + main() diff --git a/ext/dapr-ext-workflow/tests/README.md b/ext/dapr-ext-workflow/tests/README.md new file mode 100644 index 00000000..6759a362 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/README.md @@ -0,0 +1,94 @@ +## Workflow tests: unit, integration, and custom ports + +This directory contains unit tests (no sidecar required) and integration tests (require a running sidecar/runtime). + +### Prereqs + +- Python 3.11+ (tox will create an isolated venv) +- Dapr sidecar for integration tests (HTTP and gRPC ports) +- Optional: Durable Task gRPC endpoint for DT e2e tests + +### Run all tests via tox (recommended) + +```bash +tox -e py311 +``` + +This runs: +- Core SDK tests (unittest) +- Workflow extension unit tests (pytest) +- Workflow extension integration tests (pytest) if your sidecar/runtime is reachable + +### Run only workflow unit tests + +Unit tests live at `ext/dapr-ext-workflow/tests` excluding the `integration/` subfolder. + +With tox: +```bash +tox -e py311 -- pytest -q ext/dapr-ext-workflow/tests -k "not integration" +``` + +Directly (outside tox): +```bash +pytest -q ext/dapr-ext-workflow/tests -k "not integration" +``` + +### Run workflow integration tests + +Integration tests live under `ext/dapr-ext-workflow/tests/integration/` and require a running sidecar/runtime. + +With tox: +```bash +tox -e py311 -- pytest -q ext/dapr-ext-workflow/tests/integration +``` + +Directly (outside tox): +```bash +pytest -q ext/dapr-ext-workflow/tests/integration +``` + +If tests cannot reach your sidecar/runtime, they will skip or fail fast depending on the specific test. + +### Configure custom sidecar ports/endpoints + +The SDK reads connection settings from env vars (see `dapr.conf.global_settings`). Use these to point tests at custom ports: + +- Dapr gRPC: + - `DAPR_GRPC_ENDPOINT` (preferred): endpoint string, e.g. `dns:127.0.0.1:50051` + - or `DAPR_RUNTIME_HOST` and `DAPR_GRPC_PORT`, e.g. `DAPR_RUNTIME_HOST=127.0.0.1`, `DAPR_GRPC_PORT=50051` + +- Dapr HTTP (only for HTTP-based tests): + - `DAPR_HTTP_ENDPOINT`: e.g. `http://127.0.0.1:3600` + - or `DAPR_RUNTIME_HOST` and `DAPR_HTTP_PORT`, e.g. `DAPR_HTTP_PORT=3600` + +Examples: +```bash +# Use custom gRPC 50051 and HTTP 3600 +export DAPR_GRPC_ENDPOINT=dns:127.0.0.1:50051 +export DAPR_HTTP_ENDPOINT=http://127.0.0.1:3600 + +# Alternatively, using host/port pairs +export DAPR_RUNTIME_HOST=127.0.0.1 +export DAPR_GRPC_PORT=50051 +export DAPR_HTTP_PORT=3600 + +tox -e py311 -- pytest -q ext/dapr-ext-workflow/tests/integration +``` + +Note: For gRPC, avoid `http://` or `https://` schemes. Use `dns:host:port` or just set host/port separately. + +### Durable Task e2e tests (optional) + +Some tests (e.g., `integration/test_async_e2e_dt.py`) talk directly to a Durable Task gRPC endpoint. 
They use: + +- `DURABLETASK_GRPC_ENDPOINT` (default `localhost:56178`) + +If your DT runtime listens elsewhere: +```bash +export DURABLETASK_GRPC_ENDPOINT=127.0.0.1:56179 +tox -e py311 -- pytest -q ext/dapr-ext-workflow/tests/integration/test_async_e2e_dt.py +``` + + + + diff --git a/ext/dapr-ext-workflow/tests/_fakes.py b/ext/dapr-ext-workflow/tests/_fakes.py new file mode 100644 index 00000000..09051702 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/_fakes.py @@ -0,0 +1,72 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any + + +class FakeOrchestrationContext: + def __init__( + self, + *, + instance_id: str = 'wf-1', + current_utc_datetime: datetime | None = None, + is_replaying: bool = False, + workflow_name: str = 'wf', + parent_instance_id: str | None = None, + history_event_sequence: int | None = 1, + trace_parent: str | None = None, + trace_state: str | None = None, + orchestration_span_id: str | None = None, + workflow_attempt: int | None = None, + ) -> None: + self.instance_id = instance_id + self.current_utc_datetime = ( + current_utc_datetime if current_utc_datetime else datetime(2025, 1, 1) + ) + self.is_replaying = is_replaying + self.workflow_name = workflow_name + self.parent_instance_id = parent_instance_id + self.history_event_sequence = history_event_sequence + self.trace_parent = trace_parent + self.trace_state = trace_state + self.orchestration_span_id = orchestration_span_id + self.workflow_attempt = workflow_attempt + + +class FakeActivityContext: + def __init__( + self, + *, + orchestration_id: str = 'wf-1', + task_id: int = 1, + attempt: int | None = None, + trace_parent: str | None = None, + trace_state: str | None = None, + workflow_span_id: str | None = None, + ) -> None: + self.orchestration_id = orchestration_id + self.task_id = task_id + self.trace_parent = trace_parent + self.trace_state = trace_state + self.workflow_span_id = workflow_span_id + + +def make_orch_ctx(**overrides: Any) -> FakeOrchestrationContext: + return FakeOrchestrationContext(**overrides) + + +def make_act_ctx(**overrides: Any) -> FakeActivityContext: + return FakeActivityContext(**overrides) diff --git a/ext/dapr-ext-workflow/tests/conftest.py b/ext/dapr-ext-workflow/tests/conftest.py new file mode 100644 index 00000000..f20a225e --- /dev/null +++ b/ext/dapr-ext-workflow/tests/conftest.py @@ -0,0 +1,74 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +# Ensure tests prefer the local python-sdk repository over any installed site-packages +# This helps when running pytest directly (outside tox/CI), so changes in the repo are exercised. +from __future__ import annotations # noqa: I001 + +import sys +from pathlib import Path +import importlib +import pytest + + +def pytest_configure(config): # noqa: D401 (pytest hook) + """Pytest configuration hook that prepends the repo root to sys.path. + + This ensures `import dapr` resolves to the local source tree when running tests directly. + Under tox/CI (editable installs), this is a no-op but still safe. + """ + try: + # ext/dapr-ext-workflow/tests/conftest.py -> repo root is 3 parents up + repo_root = Path(__file__).resolve().parents[3] + except Exception: + return + + repo_str = str(repo_root) + if repo_str not in sys.path: + sys.path.insert(0, repo_str) + + # Best-effort diagnostic: show where dapr was imported from + try: + dapr_mod = importlib.import_module('dapr') + dapr_path = Path(getattr(dapr_mod, '__file__', '')).resolve() + where = 'site-packages' if 'site-packages' in str(dapr_path) else 'local-repo' + print(f'[dapr-ext-workflow/tests] dapr resolved from {where}: {dapr_path}', file=sys.stderr) + except Exception: + # If dapr isn't importable yet, that's fine; tests importing it later will use modified sys.path + pass + + +@pytest.fixture(autouse=True) +def cleanup_workflow_registrations(request): + """Clean up workflow/activity registration markers after each test. + + This prevents test interference when the same function objects are reused across tests. + The workflow runtime marks functions with _dapr_alternate_name and _activity_registered + attributes, which can cause 'already registered' errors in subsequent tests. + """ + yield # Run the test + + # After test completes, clean up functions defined in the test module + test_module = sys.modules.get(request.module.__name__) + if test_module: + for name in dir(test_module): + obj = getattr(test_module, name, None) + if callable(obj) and hasattr(obj, '__dict__'): + try: + # Only clean up if __dict__ is writable (not mappingproxy) + if isinstance(obj.__dict__, dict): + obj.__dict__.pop('_dapr_alternate_name', None) + obj.__dict__.pop('_activity_registered', None) + except (AttributeError, TypeError): + # Skip objects with read-only __dict__ + pass diff --git a/ext/dapr-ext-workflow/tests/integration/test_async_e2e_dt.py b/ext/dapr-ext-workflow/tests/integration/test_async_e2e_dt.py new file mode 100644 index 00000000..5c9feca6 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/integration/test_async_e2e_dt.py @@ -0,0 +1,186 @@ +# -*- coding: utf-8 -*- + +""" +Async e2e tests using durabletask worker/client directly. + +These validate basic orchestration behavior against a running sidecar +to isolate environment issues from WorkflowRuntime wiring. 
+""" + +from __future__ import annotations + +import os +import time + +import pytest +from durabletask.aio import AsyncWorkflowContext +from durabletask.client import TaskHubGrpcClient +from durabletask.worker import TaskHubGrpcWorker + +pytestmark = pytest.mark.e2e + + +def _is_runtime_available(ep_str: str) -> bool: + import socket + + try: + host, port = ep_str.split(':') + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(1) + result = sock.connect_ex((host, int(port))) + sock.close() + return result == 0 + except Exception: + return False + + +endpoint = os.getenv('DAPR_GRPC_ENDPOINT', 'localhost:50001') + +skip_if_no_runtime = pytest.mark.skipif( + not _is_runtime_available(endpoint), + reason='DurableTask runtime not available', +) + + +@skip_if_no_runtime +def test_dt_simple_activity_e2e(): + # using global read-only endpoint variable + worker = TaskHubGrpcWorker(host_address=endpoint) + client = TaskHubGrpcClient(host_address=endpoint) + + def act(ctx, x: int) -> int: + return x * 3 + + worker.add_activity(act) + + @worker.add_async_orchestrator + async def orch(ctx: AsyncWorkflowContext, x: int) -> int: + return await ctx.call_activity(act, input=x) + + worker.start() + try: + try: + if hasattr(worker, 'wait_for_ready'): + worker.wait_for_ready(timeout=10) # type: ignore[attr-defined] + except Exception: + pass + iid = f'dt-e2e-act-{int(time.time() * 1000)}' + client.schedule_new_orchestration(orch, input=5, instance_id=iid) + st = client.wait_for_orchestration_completion(iid, timeout=30) + assert st is not None + assert st.runtime_status.name == 'COMPLETED' + # Output is JSON serialized scalar + assert st.serialized_output.strip() in ('15', '"15"') + finally: + try: + worker.stop() + except Exception: + pass + + +@skip_if_no_runtime +def test_dt_timer_e2e(): + # using global read-only endpoint variable + worker = TaskHubGrpcWorker(host_address=endpoint) + client = TaskHubGrpcClient(host_address=endpoint) + + @worker.add_async_orchestrator + async def orch(ctx: AsyncWorkflowContext, delay: float) -> dict: + start = ctx.now() + await ctx.sleep(delay) + end = ctx.now() + return {'start': start.isoformat(), 'end': end.isoformat(), 'delay': delay} + + worker.start() + try: + try: + if hasattr(worker, 'wait_for_ready'): + worker.wait_for_ready(timeout=10) # type: ignore[attr-defined] + except Exception: + pass + iid = f'dt-e2e-timer-{int(time.time() * 1000)}' + delay = 1.0 + client.schedule_new_orchestration(orch, input=delay, instance_id=iid) + st = client.wait_for_orchestration_completion(iid, timeout=30) + assert st is not None + assert st.runtime_status.name == 'COMPLETED' + finally: + try: + worker.stop() + except Exception: + pass + + +@skip_if_no_runtime +def test_dt_sub_orchestrator_e2e(): + # using global read-only endpoint variable + worker = TaskHubGrpcWorker(host_address=endpoint) + client = TaskHubGrpcClient(host_address=endpoint) + + def act(ctx, s: str) -> str: + return f'A:{s}' + + worker.add_activity(act) + + async def child(ctx: AsyncWorkflowContext, s: str) -> str: + print('[E2E DEBUG] child start', s) + try: + res = await ctx.call_activity(act, input=s) + print('[E2E DEBUG] child done', res) + return res + except Exception as exc: # pragma: no cover - troubleshooting aid + import traceback as _tb + + print('[E2E DEBUG] child exception:', type(exc).__name__, str(exc)) + print(_tb.format_exc()) + raise + + # Explicit registration to avoid decorator replacing symbol with a string in newer versions + worker.add_async_orchestrator(child) + + 
async def parent(ctx: AsyncWorkflowContext, s: str) -> str: + print('[E2E DEBUG] parent start', s) + try: + c = await ctx.call_sub_orchestrator(child, input=s) + out = f'P:{c}' + print('[E2E DEBUG] parent done', out) + return out + except Exception as exc: # pragma: no cover - troubleshooting aid + import traceback as _tb + + print('[E2E DEBUG] parent exception:', type(exc).__name__, str(exc)) + print(_tb.format_exc()) + raise + + worker.add_async_orchestrator(parent) + + worker.start() + try: + try: + if hasattr(worker, 'wait_for_ready'): + worker.wait_for_ready(timeout=10) # type: ignore[attr-defined] + except Exception: + pass + iid = f'dt-e2e-sub-{int(time.time() * 1000)}' + print('[E2E DEBUG] scheduling instance', iid) + client.schedule_new_orchestration(parent, input='x', instance_id=iid) + st = client.wait_for_orchestration_completion(iid, timeout=30) + assert st is not None + if st.runtime_status.name != 'COMPLETED': + # Print orchestration state details to aid debugging + print('[E2E DEBUG] orchestration FAILED; details:') + to_json = getattr(st, 'to_json', None) + if callable(to_json): + try: + print(to_json()) + except Exception: + pass + print('status=', getattr(st, 'runtime_status', None)) + print('output=', getattr(st, 'serialized_output', None)) + print('failure=', getattr(st, 'failure_details', None)) + assert st.runtime_status.name == 'COMPLETED' + finally: + try: + worker.stop() + except Exception: + pass diff --git a/ext/dapr-ext-workflow/tests/integration/test_integration_async_semantics.py b/ext/dapr-ext-workflow/tests/integration/test_integration_async_semantics.py new file mode 100644 index 00000000..cafb125d --- /dev/null +++ b/ext/dapr-ext-workflow/tests/integration/test_integration_async_semantics.py @@ -0,0 +1,965 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the specific language governing permissions and +limitations under the License. 
+""" + +import time + +import pytest +from dapr.ext.workflow import ( + AsyncWorkflowContext, + DaprWorkflowClient, + DaprWorkflowContext, + WorkflowRuntime, +) +from dapr.ext.workflow.interceptors import ( + BaseRuntimeInterceptor, + ExecuteActivityRequest, + ExecuteWorkflowRequest, +) + +pytestmark = pytest.mark.e2e + +skip_integration = pytest.mark.skipif( + False, + reason='integration enabled', +) + + +@skip_integration +def test_integration_suspension_and_buffering(): + runtime = WorkflowRuntime() + + @runtime.async_workflow(name='suspend_orchestrator_async') + async def suspend_orchestrator(ctx: AsyncWorkflowContext): + # Expose suspension state via custom status + ctx.set_custom_status({'is_suspended': getattr(ctx, 'is_suspended', False)}) + # Wait for 'resume_event' and then complete + data = await ctx.wait_for_external_event('resume_event') + return {'resumed_with': data} + + runtime.start() + try: + # Allow connection to stabilize before scheduling + time.sleep(3) + + client = DaprWorkflowClient() + instance_id = f'suspend-int-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=suspend_orchestrator, instance_id=instance_id) + + # Wait until started + client.wait_for_workflow_start(instance_id, timeout_in_seconds=30) + + # Pause and verify state becomes SUSPENDED and custom status updates on next activation + client.pause_workflow(instance_id) + # Give the worker time to process suspension + time.sleep(1) + state = client.get_workflow_state(instance_id) + assert state is not None + assert state.runtime_status.name in ( + 'SUSPENDED', + 'RUNNING', + ) # some hubs report SUSPENDED explicitly + + # While suspended, raise the event; it should buffer + client.raise_workflow_event(instance_id, 'resume_event', data={'ok': True}) + + # Resume and expect completion + client.resume_workflow(instance_id) + final = client.wait_for_workflow_completion(instance_id, timeout_in_seconds=60) + assert final is not None + assert final.runtime_status.name == 'COMPLETED' + finally: + runtime.shutdown() + + +@skip_integration +def test_integration_generator_metadata_propagation(): + runtime = WorkflowRuntime() + + @runtime.activity(name='recv_md_gen') + def recv_md_gen(ctx, _=None): + return ctx.get_metadata() or {} + + @runtime.workflow(name='gen_parent_sets_md') + def parent_gen(ctx: DaprWorkflowContext): + ctx.set_metadata({'tenant': 'acme', 'tier': 'gold'}) + md = yield ctx.call_activity(recv_md_gen, input=None) + return md + + runtime.start() + try: + time.sleep(3) + client = DaprWorkflowClient() + iid = f'gen-md-int-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=parent_gen, instance_id=iid) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + state = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert state is not None + assert state.runtime_status.name == 'COMPLETED' + import json as _json + + out = _json.loads(state.to_json().get('serialized_output') or '{}') + assert out.get('tenant') == 'acme' + assert out.get('tier') == 'gold' + finally: + runtime.shutdown() + + +@skip_integration +def test_integration_trace_context_child_workflow(): + runtime = WorkflowRuntime() + + @runtime.activity(name='trace_probe') + def trace_probe(ctx, _=None): + return { + 'tp': getattr(ctx, 'trace_parent', None), + 'ts': getattr(ctx, 'trace_state', None), + 'wf_span': getattr(ctx, 'workflow_span_id', None), + } + + @runtime.async_workflow(name='child_trace') + async def child(ctx: AsyncWorkflowContext, _=None): + return { + 'wf_tp': getattr(ctx, 
'trace_parent', None), + 'wf_ts': getattr(ctx, 'trace_state', None), + 'wf_span': getattr(ctx, 'workflow_span_id', None), + 'act': await ctx.call_activity(trace_probe, input=None), + } + + @runtime.async_workflow(name='parent_trace') + async def parent(ctx: AsyncWorkflowContext): + child_out = await ctx.call_child_workflow(child, input=None) + return { + 'parent_tp': getattr(ctx, 'trace_parent', None), + 'parent_span': getattr(ctx, 'workflow_span_id', None), + 'child': child_out, + } + + runtime.start() + try: + time.sleep(3) + client = DaprWorkflowClient() + iid = f'trace-child-int-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=parent, instance_id=iid) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + st = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert st is not None and st.runtime_status.name == 'COMPLETED' + import json as _json + + data = _json.loads(st.to_json().get('serialized_output') or '{}') + + # TODO: assert more specifically when we have trace context information + + # Parent (engine-provided fields may be absent depending on runtime build/config) + assert isinstance(data.get('parent_tp'), (str, type(None))) + assert isinstance(data.get('parent_span'), (str, type(None))) + # Child orchestrator fields + _child = data.get('child') or {} + assert isinstance(_child.get('wf_tp'), (str, type(None))) + assert isinstance(_child.get('wf_span'), (str, type(None))) + # Activity fields under child + act = _child.get('act') or {} + assert isinstance(act.get('tp'), (str, type(None))) + assert isinstance(act.get('wf_span'), (str, type(None))) + + finally: + runtime.shutdown() + + +@skip_integration +def test_integration_trace_context_child_workflow_injected_metadata(): + # Deterministic trace propagation using interceptors via durable metadata + from dapr.ext.workflow import ( + BaseClientInterceptor, + BaseRuntimeInterceptor, + BaseWorkflowOutboundInterceptor, + CallActivityRequest, + CallChildWorkflowRequest, + ScheduleWorkflowRequest, + ) + + TRACE_KEY = 'otel.trace_id' + + class InjectTraceClient(BaseClientInterceptor): + def schedule_new_workflow(self, request: ScheduleWorkflowRequest, next): + md = dict(request.metadata or {}) + md.setdefault(TRACE_KEY, 'sdk-trace-123') + return next( + ScheduleWorkflowRequest( + workflow_name=request.workflow_name, + input=request.input, + instance_id=request.instance_id, + start_at=request.start_at, + reuse_id_policy=request.reuse_id_policy, + metadata=md, + ) + ) + + class InjectTraceOutbound(BaseWorkflowOutboundInterceptor): + def call_activity(self, request: CallActivityRequest, next): + md = dict(request.metadata or {}) + md.setdefault(TRACE_KEY, 'sdk-trace-123') + return next( + CallActivityRequest( + activity_name=request.activity_name, + input=request.input, + retry_policy=request.retry_policy, + workflow_ctx=request.workflow_ctx, + metadata=md, + ) + ) + + def call_child_workflow(self, request: CallChildWorkflowRequest, next): + md = dict(request.metadata or {}) + md.setdefault(TRACE_KEY, 'sdk-trace-123') + return next( + CallChildWorkflowRequest( + workflow_name=request.workflow_name, + input=request.input, + instance_id=request.instance_id, + workflow_ctx=request.workflow_ctx, + metadata=md, + ) + ) + + class RestoreTraceRuntime(BaseRuntimeInterceptor): + def execute_workflow(self, request: ExecuteWorkflowRequest, next): + # Ensure metadata arrives + assert isinstance((request.metadata or {}).get(TRACE_KEY), str) + return next(request) + + def execute_activity(self, request: 
ExecuteActivityRequest, next): + assert isinstance((request.metadata or {}).get(TRACE_KEY), str) + return next(request) + + runtime = WorkflowRuntime( + runtime_interceptors=[RestoreTraceRuntime()], + workflow_outbound_interceptors=[InjectTraceOutbound()], + ) + + @runtime.activity(name='trace_probe2') + def trace_probe2(ctx, _=None): + return getattr(ctx, 'get_metadata', lambda: {})().get(TRACE_KEY) + + @runtime.async_workflow(name='child_trace2') + async def child2(ctx: AsyncWorkflowContext, _=None): + return { + 'wf_md': (ctx.get_metadata() or {}).get(TRACE_KEY), + 'act_md': await ctx.call_activity(trace_probe2, input=None), + } + + @runtime.async_workflow(name='parent_trace2') + async def parent2(ctx: AsyncWorkflowContext): + out = await ctx.call_child_workflow(child2, input=None) + return { + 'parent_md': (ctx.get_metadata() or {}).get(TRACE_KEY), + 'child': out, + } + + runtime.start() + try: + time.sleep(3) + client = DaprWorkflowClient(interceptors=[InjectTraceClient()]) + iid = f'trace-child-md-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=parent2, instance_id=iid) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + st = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert st is not None and st.runtime_status.name == 'COMPLETED' + import json as _json + + data = _json.loads(st.to_json().get('serialized_output') or '{}') + assert data.get('parent_md') == 'sdk-trace-123' + child = data.get('child') or {} + assert child.get('wf_md') == 'sdk-trace-123' + assert child.get('act_md') == 'sdk-trace-123' + finally: + runtime.shutdown() + + +@skip_integration +def test_integration_termination_semantics(): + runtime = WorkflowRuntime() + + @runtime.async_workflow(name='termination_orchestrator_async') + async def termination_orchestrator(ctx: AsyncWorkflowContext): + # Long timer; test will terminate before it fires + await ctx.create_timer(300.0) + return 'not-reached' + + print(list(runtime._WorkflowRuntime__worker._registry.orchestrators.keys())) + + runtime.start() + try: + time.sleep(3) + + client = DaprWorkflowClient() + instance_id = f'term-int-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=termination_orchestrator, instance_id=instance_id) + client.wait_for_workflow_start(instance_id, timeout_in_seconds=30) + + # Terminate and assert TERMINATED state, not raising inside orchestrator + client.terminate_workflow(instance_id, output='terminated') + final = client.wait_for_workflow_completion(instance_id, timeout_in_seconds=60) + assert final is not None + assert final.runtime_status.name == 'TERMINATED' + finally: + runtime.shutdown() + + +@skip_integration +def test_integration_when_any_first_wins(): + runtime = WorkflowRuntime() + + @runtime.async_workflow(name='when_any_async') + async def when_any_orchestrator(ctx: AsyncWorkflowContext): + first = await ctx.when_any( + [ + ctx.wait_for_external_event('go'), + ctx.create_timer(300.0), + ] + ) + # Return a simple, serializable value (winner's result) to avoid output serialization issues + try: + result = first.get_result() + except Exception: + result = None + return {'winner_result': result} + + runtime.start() + try: + # Ensure worker has established streams before scheduling + try: + if hasattr(runtime, 'wait_for_ready'): + runtime.wait_for_ready(timeout=15) # type: ignore[attr-defined] + except Exception: + pass + time.sleep(2) + + client = DaprWorkflowClient() + instance_id = f'whenany-int-{int(time.time() * 1000)}' + 
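# Schedule first, then raise the 'go' event below so the external-event task wins + # the when_any race ahead of the 300s timer. + 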
client.schedule_new_workflow(workflow=when_any_orchestrator, instance_id=instance_id) + client.wait_for_workflow_start(instance_id, timeout_in_seconds=30) + # Confirm RUNNING state before raising event (mitigates race conditions) + try: + st = client.get_workflow_state(instance_id, fetch_payloads=False) + if ( + st is None + or getattr(st, 'runtime_status', None) is None + or st.runtime_status.name != 'RUNNING' + ): + end = time.time() + 10 + while time.time() < end: + st = client.get_workflow_state(instance_id, fetch_payloads=False) + if ( + st is not None + and getattr(st, 'runtime_status', None) is not None + and st.runtime_status.name == 'RUNNING' + ): + break + time.sleep(0.2) + except Exception: + pass + + # Raise event immediately to win the when_any + client.raise_workflow_event(instance_id, 'go', data={'ok': True}) + + # Brief delay to allow event processing, then strictly use DaprWorkflowClient + time.sleep(1.0) + final = None + try: + final = client.wait_for_workflow_completion(instance_id, timeout_in_seconds=60) + except TimeoutError: + final = None + if final is None: + deadline = time.time() + 30 + while time.time() < deadline: + s = client.get_workflow_state(instance_id, fetch_payloads=False) + if s is not None and getattr(s, 'runtime_status', None) is not None: + if s.runtime_status.name in ('COMPLETED', 'FAILED', 'TERMINATED'): + final = s + break + time.sleep(0.5) + assert final is not None + assert final.runtime_status.name == 'COMPLETED' + # TODO: when sidecar exposes command diagnostics, assert only one command set was emitted + finally: + runtime.shutdown() + + +@skip_integration +def test_integration_async_activity_completes(): + runtime = WorkflowRuntime() + + @runtime.activity(name='echo_int') + def echo_act(ctx, x: int) -> int: + return x + + @runtime.async_workflow(name='async_activity_once') + async def wf(ctx: AsyncWorkflowContext): + out = await ctx.call_activity(echo_act, input=7) + return out + + runtime.start() + try: + time.sleep(3) + client = DaprWorkflowClient() + iid = f'act-int-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=wf, instance_id=iid) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + state = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert state is not None + if state.runtime_status.name != 'COMPLETED': + fd = getattr(state, 'failure_details', None) + msg = getattr(fd, 'message', None) if fd else None + et = getattr(fd, 'error_type', None) if fd else None + print(f'[INTEGRATION DEBUG] Failure details: {et} {msg}') + assert state.runtime_status.name == 'COMPLETED' + finally: + runtime.shutdown() + + +@skip_integration +def test_integration_metadata_outbound_to_activity(): + runtime = WorkflowRuntime() + + @runtime.activity(name='recv_md') + def recv_md(ctx, _=None): + md = ctx.get_metadata() if hasattr(ctx, 'get_metadata') else {} + return md + + @runtime.async_workflow(name='wf_with_md') + async def wf(ctx: AsyncWorkflowContext): + ctx.set_metadata({'tenant': 'acme'}) + md = await ctx.call_activity(recv_md, input=None) + return md + + runtime.start() + try: + time.sleep(3) + client = DaprWorkflowClient() + iid = f'md-int-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=wf, instance_id=iid) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + state = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert state is not None + assert state.runtime_status.name == 'COMPLETED' + finally: + runtime.shutdown() + + +@skip_integration +def 
test_integration_metadata_outbound_to_child_workflow(): + runtime = WorkflowRuntime() + + @runtime.async_workflow(name='child_recv_md') + async def child(ctx: AsyncWorkflowContext, _=None): + # Echo inbound metadata + return ctx.get_metadata() or {} + + @runtime.async_workflow(name='parent_sets_md') + async def parent(ctx: AsyncWorkflowContext): + ctx.set_metadata({'tenant': 'acme', 'role': 'user'}) + out = await ctx.call_child_workflow(child, input=None) + return out + + runtime.start() + try: + time.sleep(3) + client = DaprWorkflowClient() + iid = f'md-child-int-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=parent, instance_id=iid) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + state = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert state is not None + assert state.runtime_status.name == 'COMPLETED' + # Validate output has metadata keys + data = state.to_json() + import json as _json + + out = _json.loads(data.get('serialized_output') or '{}') + assert out.get('tenant') == 'acme' + assert out.get('role') == 'user' + finally: + runtime.shutdown() + + +@skip_integration +def test_integration_trace_context_with_runtime_interceptors(): + """E2E: Verify trace_parent and orchestration_span_id via runtime interceptors.""" + records = { # captured by interceptor + 'wf_tp': None, + 'wf_span': None, + 'act_tp': None, + 'act_span': None, + } + + class TraceInterceptor(BaseRuntimeInterceptor): + def execute_workflow(self, request: ExecuteWorkflowRequest, next): # type: ignore[override] + ctx = request.ctx + try: + records['wf_tp'] = getattr(ctx, 'trace_parent', None) + records['wf_span'] = getattr(ctx, 'workflow_span_id', None) + except Exception: + pass + return next(request) + + def execute_activity(self, request: ExecuteActivityRequest, next): # type: ignore[override] + ctx = request.ctx + try: + records['act_tp'] = getattr(ctx, 'trace_parent', None) + # Activity contexts don't have orchestration_span_id; capture task span if present + records['act_span'] = getattr(ctx, 'activity_span_id', None) + except Exception: + pass + return next(request) + + runtime = WorkflowRuntime(runtime_interceptors=[TraceInterceptor()]) + + @runtime.activity(name='trace_probe') + def trace_probe(ctx, _=None): + # Return trace context seen inside activity + return { + 'trace_parent': getattr(ctx, 'trace_parent', None), + 'trace_state': getattr(ctx, 'trace_state', None), + } + + @runtime.async_workflow(name='trace_parent_wf') + async def wf(ctx: AsyncWorkflowContext): + # Access orchestration span id and trace parent from workflow context + _ = getattr(ctx, 'workflow_span_id', None) + _ = getattr(ctx, 'trace_parent', None) + return await ctx.call_activity(trace_probe, input=None) + + runtime.start() + try: + time.sleep(3) + client = DaprWorkflowClient() + iid = f'trace-int-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=wf, instance_id=iid) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + state = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert state is not None + assert state.runtime_status.name == 'COMPLETED' + import json as _json + + out = _json.loads(state.to_json().get('serialized_output') or '{}') + # Activity returned strings (may be empty); assert types + assert isinstance(out.get('trace_parent'), (str, type(None))) + assert isinstance(out.get('trace_state'), (str, type(None))) + # Interceptor captured workflow and activity contexts + wf_tp = records['wf_tp'] + wf_span = records['wf_span'] + 
act_tp = records['act_tp'] + # TODO: assert more specifically when we have trace context information + assert isinstance(wf_tp, (str, type(None))) + assert isinstance(wf_span, (str, type(None))) + assert isinstance(act_tp, (str, type(None))) + # If we have a workflow span id, it should appear as parent-id inside activity traceparent + if isinstance(wf_span, str) and wf_span and isinstance(act_tp, str) and act_tp: + assert wf_span.lower() in act_tp.lower() + finally: + runtime.shutdown() + + +@skip_integration +def test_integration_runtime_shutdown_is_clean(): + runtime = WorkflowRuntime() + + @runtime.async_workflow(name='noop') + async def noop(ctx: AsyncWorkflowContext): + return 'ok' + + runtime.start() + try: + time.sleep(2) + client = DaprWorkflowClient() + iid = f'shutdown-int-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=noop, instance_id=iid) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + st = client.wait_for_workflow_completion(iid, timeout_in_seconds=30) + assert st is not None and st.runtime_status.name == 'COMPLETED' + finally: + # Call shutdown multiple times to ensure idempotent and clean behavior + for _ in range(3): + try: + runtime.shutdown() + except Exception: + # Test should not raise even if worker logs cancellation warnings + assert False, 'runtime.shutdown() raised unexpectedly' + # Recreate and shutdown again to ensure no lingering background threads break next startup + rt2 = WorkflowRuntime() + rt2.start() + try: + time.sleep(1) + finally: + try: + rt2.shutdown() + except Exception: + assert False, 'second runtime.shutdown() raised unexpectedly' + + +@skip_integration +def test_integration_continue_as_new_outbound_interceptor_metadata(): + # Verify continue_as_new outbound interceptor can inject metadata carried to the new run + from dapr.ext.workflow import BaseWorkflowOutboundInterceptor + + INJECT_KEY = 'injected' + + class InjectOnContinueAsNew(BaseWorkflowOutboundInterceptor): + def continue_as_new(self, request, next): # type: ignore[override] + md = dict(request.metadata or {}) + md.setdefault(INJECT_KEY, 'yes') + request.metadata = md + return next(request) + + runtime = WorkflowRuntime( + workflow_outbound_interceptors=[InjectOnContinueAsNew()], + ) + + @runtime.workflow(name='continue_as_new_probe') + def wf(ctx, arg: dict | None = None): + if not arg or arg.get('phase') != 'second': + ctx.set_metadata({'tenant': 'acme'}) + # carry over existing metadata; interceptor will also inject + ctx.continue_as_new({'phase': 'second'}, carryover_metadata=True) + return # Must not yield after continue_as_new + # Second run: return inbound metadata observed + return ctx.get_metadata() or {} + + runtime.start() + try: + time.sleep(2) + client = DaprWorkflowClient() + iid = f'can-int-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=wf, instance_id=iid) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + st = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert st is not None and st.runtime_status.name == 'COMPLETED' + import json as _json + + out = _json.loads(st.to_json().get('serialized_output') or '{}') + # Confirm both carried and injected metadata are present + assert out.get('tenant') == 'acme' + assert out.get(INJECT_KEY) == 'yes' + finally: + runtime.shutdown() + + +@skip_integration +def test_integration_child_workflow_attempt_exposed(): + # Verify that child workflow ctx exposes workflow_attempt + runtime = WorkflowRuntime() + + 
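# The child probes ctx.workflow_attempt, which may be absent on some runtimes; the + # assertion at the end accepts either None or an int. + 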
@runtime.async_workflow(name='child_probe_attempt') + async def child_probe_attempt(ctx: AsyncWorkflowContext, _=None): + att = getattr(ctx, 'workflow_attempt', None) + return {'wf_attempt': att} + + @runtime.async_workflow(name='parent_calls_child_for_attempt') + async def parent_calls_child_for_attempt(ctx: AsyncWorkflowContext): + return await ctx.call_child_workflow(child_probe_attempt, input=None) + + runtime.start() + try: + time.sleep(2) + client = DaprWorkflowClient() + iid = f'child-attempt-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=parent_calls_child_for_attempt, instance_id=iid) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + st = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert st is not None and st.runtime_status.name == 'COMPLETED' + import json as _json + + out = _json.loads(st.to_json().get('serialized_output') or '{}') + val = out.get('wf_attempt', None) + assert (val is None) or isinstance(val, int) + finally: + runtime.shutdown() + + +@skip_integration +def test_integration_async_contextvars_trace_propagation(monkeypatch): + # Demonstrates contextvars-based trace propagation via interceptors in async workflows + import contextvars + import json as _json + + from dapr.ext.workflow import ( + BaseClientInterceptor, + BaseWorkflowOutboundInterceptor, + CallActivityRequest, + CallChildWorkflowRequest, + ScheduleWorkflowRequest, + ) + + TRACE_KEY = 'otel.trace_ctx' + current_trace: contextvars.ContextVar[str | None] = contextvars.ContextVar( + 'trace', default=None + ) + + class CVClient(BaseClientInterceptor): + def schedule_new_workflow(self, request: ScheduleWorkflowRequest, next): # type: ignore[override] + md = dict(request.metadata or {}) + md.setdefault(TRACE_KEY, 'wf-parent') + return next( + ScheduleWorkflowRequest( + workflow_name=request.workflow_name, + input=request.input, + instance_id=request.instance_id, + start_at=request.start_at, + reuse_id_policy=request.reuse_id_policy, + metadata=md, + ) + ) + + class CVOutbound(BaseWorkflowOutboundInterceptor): + def call_activity(self, request: CallActivityRequest, next): # type: ignore[override] + md = dict(request.metadata or {}) + md.setdefault(TRACE_KEY, current_trace.get()) + return next( + CallActivityRequest( + activity_name=request.activity_name, + input=request.input, + retry_policy=request.retry_policy, + workflow_ctx=request.workflow_ctx, + metadata=md, + ) + ) + + def call_child_workflow(self, request: CallChildWorkflowRequest, next): # type: ignore[override] + md = dict(request.metadata or {}) + md.setdefault(TRACE_KEY, current_trace.get()) + return next( + CallChildWorkflowRequest( + workflow_name=request.workflow_name, + input=request.input, + instance_id=request.instance_id, + workflow_ctx=request.workflow_ctx, + metadata=md, + ) + ) + + class CVRuntime(BaseRuntimeInterceptor): + def execute_workflow(self, request: ExecuteWorkflowRequest, next): # type: ignore[override] + prev = current_trace.set((request.metadata or {}).get(TRACE_KEY)) + try: + return next(request) + finally: + current_trace.reset(prev) + + def execute_activity(self, request: ExecuteActivityRequest, next): # type: ignore[override] + prev = current_trace.set((request.metadata or {}).get(TRACE_KEY)) + try: + return next(request) + finally: + current_trace.reset(prev) + + runtime = WorkflowRuntime( + runtime_interceptors=[CVRuntime()], workflow_outbound_interceptors=[CVOutbound()] + ) + + @runtime.activity(name='cv_probe') + def cv_probe(_ctx, _=None): + before = 
current_trace.get() + tok = current_trace.set(f'{before}/act') if before else None + try: + inner = current_trace.get() + finally: + if tok is not None: + current_trace.reset(tok) + after = current_trace.get() + return {'before': before, 'inner': inner, 'after': after} + + flaky_call_count = [0] + + @runtime.activity(name='cv_flaky_probe') + def cv_flaky_probe(ctx, _=None): + before = current_trace.get() + flaky_call_count[0] += 1 + print(f'----------> flaky_call_count: {flaky_call_count[0]}') + if flaky_call_count[0] == 1: + # Fail first attempt to trigger retry + raise Exception('fail-once') + tok = current_trace.set(f'{before}/act-retry') if before else None + try: + inner = current_trace.get() + finally: + if tok is not None: + current_trace.reset(tok) + after = current_trace.get() + return {'before': before, 'inner': inner, 'after': after} + + @runtime.async_workflow(name='cv_child') + async def cv_child(ctx: AsyncWorkflowContext, _=None): + before = current_trace.get() + tok = current_trace.set(f'{before}/child') if before else None + try: + act = await ctx.call_activity(cv_probe, input=None) + finally: + if tok is not None: + current_trace.reset(tok) + restored = current_trace.get() + return {'before': before, 'restored': restored, 'act': act} + + @runtime.async_workflow(name='cv_parent') + async def cv_parent(ctx: AsyncWorkflowContext, _=None): + from datetime import timedelta + + from dapr.ext.workflow import RetryPolicy + + top_before = current_trace.get() + child = await ctx.call_child_workflow(cv_child, input=None) + after_child = current_trace.get() + act = await ctx.call_activity(cv_probe, input=None) + after_act = current_trace.get() + act_retry = await ctx.call_activity( + cv_flaky_probe, + input=None, + retry_policy=RetryPolicy( + first_retry_interval=timedelta(seconds=0), max_number_of_attempts=3 + ), + ) + return { + 'before': top_before, + 'child': child, + 'act': act, + 'act_retry': act_retry, + 'after_child': after_child, + 'after_act': after_act, + } + + runtime.start() + try: + time.sleep(2) + client = DaprWorkflowClient(interceptors=[CVClient()]) + iid = f'cv-ctx-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=cv_parent, instance_id=iid) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + st = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert st is not None and st.runtime_status.name == 'COMPLETED' + out = _json.loads(st.to_json().get('serialized_output') or '{}') + # Top-level activity sees parent trace context during execution + act = out.get('act') or {} + assert act.get('before') == 'wf-parent' + assert act.get('inner') == 'wf-parent/act' + assert act.get('after') == 'wf-parent' + # Child workflow's activity at least inherits parent context + child = out.get('child') or {} + child_act = child.get('act') or {} + assert child_act.get('before') == 'wf-parent' + assert child_act.get('inner') == 'wf-parent/act' + assert child_act.get('after') == 'wf-parent' + # Flaky activity retried: second attempt succeeds and returns with parent context + act_retry = out.get('act_retry') or {} + assert act_retry.get('before') == 'wf-parent' + assert act_retry.get('inner') == 'wf-parent/act-retry' + assert act_retry.get('after') == 'wf-parent' + finally: + runtime.shutdown() + + +def test_runtime_interceptor_shapes_async_input(): + runtime = WorkflowRuntime() + + class ShapeInput(BaseRuntimeInterceptor): + def execute_workflow(self, request: ExecuteWorkflowRequest, next): # type: ignore[override] + data = request.input + # 
Mutate input passed to workflow + if isinstance(data, dict): + shaped = {**data, 'shaped': True} + else: + shaped = {'value': data, 'shaped': True} + request.input = shaped + return next(request) + + # Recreate runtime with interceptor wired in + runtime = WorkflowRuntime(runtime_interceptors=[ShapeInput()]) + + @runtime.async_workflow(name='wf_shape_input') + async def wf_shape_input(ctx: AsyncWorkflowContext, arg: dict | None = None): + # Verify shaped input is observed by the workflow + return arg + + runtime.start() + try: + from dapr.ext.workflow import DaprWorkflowClient + + client = DaprWorkflowClient() + iid = f'shape-{id(runtime)}' + client.schedule_new_workflow(workflow=wf_shape_input, instance_id=iid, input={'x': 1}) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + st = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert st is not None + assert st.runtime_status.name == 'COMPLETED' + import json as _json + + out = _json.loads(st.to_json().get('serialized_output') or '{}') + assert out.get('x') == 1 + assert out.get('shaped') is True + finally: + runtime.shutdown() + + +def test_runtime_interceptor_context_manager_with_async_workflow(): + """Test that context managers stay active during async workflow execution.""" + runtime = WorkflowRuntime() + + # Track when context enters and exits + context_state = {'entered': False, 'exited': False, 'workflow_ran': False} + + class ContextInterceptor(BaseRuntimeInterceptor): + def execute_workflow(self, request: ExecuteWorkflowRequest, next): # type: ignore[override] + # Wrapper generator to keep context manager alive + def wrapper(): + from contextlib import ExitStack + + with ExitStack(): + # Mark context as entered + context_state['entered'] = True + + # Get the workflow generator + gen = next(request) + + # Use yield from to keep context alive during execution + yield from gen + + # Context will exit after generator completes + context_state['exited'] = True + + return wrapper() + + runtime = WorkflowRuntime(runtime_interceptors=[ContextInterceptor()]) + + @runtime.async_workflow(name='wf_context_test') + async def wf_context_test(ctx: AsyncWorkflowContext, arg: dict | None = None): + context_state['workflow_ran'] = True + return {'result': 'ok'} + + runtime.start() + try: + from dapr.ext.workflow import DaprWorkflowClient + + client = DaprWorkflowClient() + iid = f'ctx-test-{id(runtime)}' + client.schedule_new_workflow(workflow=wf_context_test, instance_id=iid, input={}) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + st = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert st is not None + assert st.runtime_status.name == 'COMPLETED' + + # Verify context manager was active during workflow execution + assert context_state['entered'], 'Context should have been entered' + assert context_state['workflow_ran'], 'Workflow should have executed' + assert context_state['exited'], 'Context should have exited after completion' + finally: + runtime.shutdown() diff --git a/ext/dapr-ext-workflow/tests/test_async_activity_registration.py b/ext/dapr-ext-workflow/tests/test_async_activity_registration.py new file mode 100644 index 00000000..cbffdd5a --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_async_activity_registration.py @@ -0,0 +1,58 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the specific language governing permissions and +limitations under the License. +""" + +from dapr.ext.workflow.workflow_runtime import WorkflowRuntime + + +class _FakeRegistry: + def __init__(self): + self.activities = {} + + def add_named_activity(self, name, fn): + self.activities[name] = fn + + +class _FakeWorker: + def __init__(self, *args, **kwargs): + self._registry = _FakeRegistry() + + def start(self): + pass + + def stop(self): + pass + + +def test_activity_decorator_supports_async(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + rt = WorkflowRuntime() + + @rt.activity(name='async_act') + async def async_act(ctx, x: int) -> int: + return x + 2 + + # Ensure registered + reg = rt._WorkflowRuntime__worker._registry + assert 'async_act' in reg.activities + + # Call the wrapper and ensure it runs the coroutine to completion + wrapper = reg.activities['async_act'] + + class _Ctx: + pass + + out = wrapper(_Ctx(), 5) + assert out == 7 diff --git a/ext/dapr-ext-workflow/tests/test_async_activity_retry_failure.py b/ext/dapr-ext-workflow/tests/test_async_activity_retry_failure.py new file mode 100644 index 00000000..6433735f --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_async_activity_retry_failure.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import pytest +from dapr.ext.workflow.aio import AsyncWorkflowContext, CoroutineOrchestratorRunner + + +class FakeTask: + def __init__(self, name: str): + self.name = name + + +class FakeCtx: + def __init__(self): + import datetime + + self.current_utc_datetime = datetime.datetime(2024, 1, 1) + self.instance_id = 'iid-act-retry' + + def call_activity(self, activity, *, input=None, retry_policy=None, metadata=None, app_id=None): + return FakeTask('activity') + + def create_timer(self, fire_at): + return FakeTask('timer') + + +async def wf(ctx: AsyncWorkflowContext): + # One activity that ultimately fails after retries + await ctx.call_activity(lambda: None, retry_policy={'dummy': True}) + return 'not-reached' + + +def test_activity_retry_final_failure_raises(): + fake = FakeCtx() + runner = CoroutineOrchestratorRunner(wf) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + # Prime + next(gen) + # Simulate final failure after retry policy exhausts + with pytest.raises(RuntimeError, match='activity failed'): + gen.throw(RuntimeError('activity failed')) diff --git a/ext/dapr-ext-workflow/tests/test_async_api_coverage.py b/ext/dapr-ext-workflow/tests/test_async_api_coverage.py new file mode 100644 index 00000000..94004446 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_async_api_coverage.py @@ -0,0 +1,80 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +from datetime import datetime + +from dapr.ext.workflow.aio import AsyncWorkflowContext + + +class FakeCtx: + def __init__(self): + self.current_utc_datetime = datetime(2024, 1, 1) + self.instance_id = 'iid-cov' + self._status = None + + def set_custom_status(self, status): + self._status = status + + def continue_as_new(self, new_input, *, save_events=False): + self._continued = (new_input, save_events) + + # methods used by awaitables + def call_activity(self, activity, *, input=None, retry_policy=None, app_id=None): + class _T: + pass + + return _T() + + def call_child_workflow(self, workflow, *, input=None, instance_id=None, retry_policy=None, app_id=None): + class _T: + pass + + return _T() + + def create_timer(self, fire_at): + class _T: + pass + + return _T() + + def wait_for_external_event(self, name: str): + class _T: + pass + + return _T() + + +def test_async_context_exposes_required_methods(): + base = FakeCtx() + ctx = AsyncWorkflowContext(base) + + # basic deterministic utils existence + assert isinstance(ctx.now(), datetime) + _ = ctx.random() + _ = ctx.uuid4() + + # pass-throughs + ctx.set_custom_status('ok') + assert base._status == 'ok' + ctx.continue_as_new({'foo': 1}, save_events=True) + assert getattr(base, '_continued', None) == ({'foo': 1}, True) + + # awaitable constructors do not raise + ctx.call_activity(lambda: None, input={'x': 1}) + ctx.call_child_workflow(lambda: None) + ctx.sleep(1.0) + ctx.wait_for_external_event('go') + ctx.when_all([]) + ctx.when_any([]) diff --git a/ext/dapr-ext-workflow/tests/test_async_concurrency_and_determinism.py b/ext/dapr-ext-workflow/tests/test_async_concurrency_and_determinism.py new file mode 100644 index 00000000..4da50ded --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_async_concurrency_and_determinism.py @@ -0,0 +1,112 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +from datetime import datetime + +from dapr.ext.workflow.aio import AsyncWorkflowContext, CoroutineOrchestratorRunner +from durabletask import task as durable_task_module +from durabletask.deterministic import deterministic_random, deterministic_uuid4 + + +class FakeTask: + def __init__(self, name: str): + self.name = name + + +class FakeCtx: + def __init__(self): + self.current_utc_datetime = datetime(2024, 1, 1) + self.instance_id = 'iid-123' + + def call_activity(self, activity, *, input=None, retry_policy=None, metadata=None, app_id=None): + return FakeTask(f'activity:{getattr(activity, "__name__", str(activity))}') + + def call_child_workflow( + self, workflow, *, input=None, instance_id=None, retry_policy=None, metadata=None, app_id=None + ): + return FakeTask(f'sub:{getattr(workflow, "__name__", str(workflow))}') + + def create_timer(self, fire_at): + return FakeTask('timer') + + def wait_for_external_event(self, name: str): + return FakeTask(f'event:{name}') + + +def drive_first_wins(gen, winner_name): + # Simulate when_any: first send the winner, then finish + next(gen) # prime + result = gen.send({'task': winner_name}) + # the coroutine should complete; StopIteration will be raised by caller + return result + + +async def wf_when_all(ctx: AsyncWorkflowContext): + a = ctx.call_activity(lambda: None) + b = ctx.sleep(1.0) + res = await ctx.when_all([a, b]) + return res + + +def test_when_all_maps_and_completes(monkeypatch): + # Patch durabletask.when_all to accept our FakeTask inputs and return a FakeTask + monkeypatch.setattr(durable_task_module, 'when_all', lambda tasks: FakeTask('when_all')) + fake = FakeCtx() + runner = CoroutineOrchestratorRunner(wf_when_all) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + # Drive two yields: when_all yields a task once; we simply return a list result + try: + t = gen.send(None) + assert isinstance(t, FakeTask) + out = gen.send([{'task': 'activity:lambda'}, {'task': 'timer'}]) + except StopIteration as stop: + out = stop.value + assert isinstance(out, list) + assert len(out) == 2 + + +async def wf_when_any(ctx: AsyncWorkflowContext): + a = ctx.call_activity(lambda: None) + b = ctx.sleep(5.0) + first = await ctx.when_any([a, b]) + # Return the first result only; losers ignored deterministically + return first + + +def test_when_any_first_wins_behavior(monkeypatch): + monkeypatch.setattr(durable_task_module, 'when_any', lambda tasks: FakeTask('when_any')) + fake = FakeCtx() + runner = CoroutineOrchestratorRunner(wf_when_any) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + try: + t = gen.send(None) + assert isinstance(t, FakeTask) + out = gen.send({'task': 'activity:lambda'}) + except StopIteration as stop: + out = stop.value + assert out == {'task': 'activity:lambda'} + + +def test_deterministic_random_and_uuid_are_stable(): + iid = 'iid-123' + now = datetime(2024, 1, 1) + rnd1 = deterministic_random(iid, now) + rnd2 = deterministic_random(iid, now) + seq1 = [rnd1.random() for _ in range(5)] + seq2 = [rnd2.random() for _ in range(5)] + assert seq1 == seq2 + u1 = deterministic_uuid4(deterministic_random(iid, now)) + u2 = deterministic_uuid4(deterministic_random(iid, now)) + assert u1 == u2 diff --git a/ext/dapr-ext-workflow/tests/test_async_context.py b/ext/dapr-ext-workflow/tests/test_async_context.py new file mode 100644 index 00000000..1ca0e57a --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_async_context.py @@ -0,0 +1,202 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, 
Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import types +from datetime import datetime, timedelta, timezone + +from dapr.ext.workflow import AsyncWorkflowContext +from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext +from dapr.ext.workflow.workflow_context import WorkflowContext + + +class DummyBaseCtx: + def __init__(self): + self.instance_id = 'abc-123' + # freeze a deterministic timestamp + self.current_utc_datetime = datetime(2025, 1, 2, 3, 4, 5, tzinfo=timezone.utc) + self.is_replaying = False + self._custom_status = None + self._continued = None + self._metadata = None + self._ei = types.SimpleNamespace( + workflow_id='abc-123', + workflow_name='wf', + is_replaying=False, + history_event_sequence=1, + inbound_metadata={'a': 'b'}, + parent_instance_id=None, + ) + + def set_custom_status(self, s: str): + self._custom_status = s + + def continue_as_new(self, new_input, *, save_events: bool = False): + self._continued = (new_input, save_events) + + # Metadata parity + def set_metadata(self, md): + self._metadata = md + + def get_metadata(self): + return self._metadata + + @property + def execution_info(self): + return self._ei + + +def test_parity_properties_and_now(): + ctx = AsyncWorkflowContext(DummyBaseCtx()) + assert ctx.instance_id == 'abc-123' + assert ctx.current_utc_datetime == datetime(2025, 1, 2, 3, 4, 5, tzinfo=timezone.utc) + # now() should mirror current_utc_datetime + assert ctx.now() == ctx.current_utc_datetime + + +def test_timer_accepts_float_and_timedelta(): + base = DummyBaseCtx() + ctx = AsyncWorkflowContext(base) + + # Float should be interpreted as seconds and produce a SleepAwaitable + aw1 = ctx.create_timer(1.5) + # Timedelta should pass through + aw2 = ctx.create_timer(timedelta(seconds=2)) + + # We only assert types by duck-typing public attribute presence to avoid + # importing internal classes in tests + assert hasattr(aw1, '_ctx') and hasattr(aw1, '__await__') + assert hasattr(aw2, '_ctx') and hasattr(aw2, '__await__') + + +def test_wait_for_external_event_and_concurrency_factories(): + ctx = AsyncWorkflowContext(DummyBaseCtx()) + + evt = ctx.wait_for_external_event('go') + assert hasattr(evt, '__await__') + + # when_all/when_any/gather return awaitables + a = ctx.create_timer(0.1) + b = ctx.create_timer(0.2) + + all_aw = ctx.when_all([a, b]) + any_aw = ctx.when_any([a, b]) + gat_aw = ctx.gather(a, b) + gat_exc_aw = ctx.gather(a, b, return_exceptions=True) + + for x in (all_aw, any_aw, gat_aw, gat_exc_aw): + assert hasattr(x, '__await__') + + +def test_deterministic_utils_and_passthroughs(): + base = DummyBaseCtx() + ctx = AsyncWorkflowContext(base) + + rnd = ctx.random() + # should behave like a random.Random-like object; test a stable first value + val = rnd.random() + # Just assert it is within (0,1) and stable across two calls to the seeded RNG instance + assert 0.0 < val < 1.0 + assert rnd.random() != val # next value changes + + uid = ctx.uuid4() + # Should be a UUID-like string representation + assert isinstance(str(uid), str) and len(str(uid)) >= 32 + + # passthroughs + 
ctx.set_custom_status('hello') + assert base._custom_status == 'hello' + + ctx.continue_as_new({'x': 1}, save_events=True) + assert base._continued == ({'x': 1}, True) + + +def test_async_metadata_api_and_execution_info(): + base = DummyBaseCtx() + ctx = AsyncWorkflowContext(base) + ctx.set_metadata({'k': 'v'}) + assert base._metadata == {'k': 'v'} + assert ctx.get_metadata() == {'k': 'v'} + ei = ctx.execution_info + assert ei and ei.workflow_id == 'abc-123' and ei.workflow_name == 'wf' + + +def test_async_outbound_metadata_plumbed_into_awaitables(): + base = DummyBaseCtx() + ctx = AsyncWorkflowContext(base) + a = ctx.call_activity(lambda: None, input=1, metadata={'m': 'n'}) + c = ctx.call_child_workflow(lambda c, x: None, input=2, metadata={'x': 'y'}) + # Introspect for test (internal attribute) + assert getattr(a, '_metadata', None) == {'m': 'n'} + assert getattr(c, '_metadata', None) == {'x': 'y'} + + +def test_async_parity_surface_exists(): + # Guard: ensure essential parity members exist + ctx = AsyncWorkflowContext(DummyBaseCtx()) + for name in ( + 'set_metadata', + 'get_metadata', + 'execution_info', + 'call_activity', + 'call_child_workflow', + 'continue_as_new', + ): + assert hasattr(ctx, name) + + +def test_public_api_parity_against_workflowcontext_abc(): + # Derive the required sync API surface from the ABC plus metadata/execution_info + required = { + name + for name, attr in WorkflowContext.__dict__.items() + if getattr(attr, '__isabstractmethod__', False) + } + required.update({'set_metadata', 'get_metadata', 'execution_info'}) + + # Async context must expose the same names + async_ctx = AsyncWorkflowContext(DummyBaseCtx()) + missing_in_async = [name for name in required if not hasattr(async_ctx, name)] + assert not missing_in_async, f'AsyncWorkflowContext missing: {missing_in_async}' + + # Sync context should also expose these names + class _FakeOrchCtx: + def __init__(self): + self.instance_id = 'abc-123' + self.current_utc_datetime = datetime(2025, 1, 2, 3, 4, 5, tzinfo=timezone.utc) + self.is_replaying = False + + def set_custom_status(self, s: str): + pass + + def create_timer(self, fire_at): + return object() + + def wait_for_external_event(self, name: str): + return object() + + def continue_as_new(self, new_input, *, save_events: bool = False): + pass + + def call_activity( + self, *, activity, input=None, retry_policy=None, app_id: str | None = None + ): + return object() + + def call_sub_orchestrator( + self, fn, *, input=None, instance_id=None, retry_policy=None, app_id: str | None = None + ): + return object() + + sync_ctx = DaprWorkflowContext(_FakeOrchCtx()) + missing_in_sync = [name for name in required if not hasattr(sync_ctx, name)] + assert not missing_in_sync, f'DaprWorkflowContext missing: {missing_in_sync}' diff --git a/ext/dapr-ext-workflow/tests/test_async_errors_and_backcompat.py b/ext/dapr-ext-workflow/tests/test_async_errors_and_backcompat.py new file mode 100644 index 00000000..2a5c8147 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_async_errors_and_backcompat.py @@ -0,0 +1,169 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from dapr.ext.workflow.aio import AsyncWorkflowContext, CoroutineOrchestratorRunner +from dapr.ext.workflow.workflow_runtime import WorkflowRuntime + + +class FakeTask: + def __init__(self, name: str): + self.name = name + + +class FakeOrchestrationContext: + def __init__(self): + import datetime + + self.current_utc_datetime = datetime.datetime(2024, 1, 1) + self.instance_id = 'iid-errors' + self.is_replaying = False + self._custom_status = None + self.workflow_name = 'wf' + self.parent_instance_id = None + self.history_event_sequence = 1 + self.trace_parent = None + self.trace_state = None + self.orchestration_span_id = None + + def call_activity(self, activity, *, input=None, retry_policy=None, metadata=None, app_id=None): + return FakeTask('activity') + + def create_timer(self, fire_at): + return FakeTask('timer') + + def wait_for_external_event(self, name: str): + return FakeTask(f'event:{name}') + + def set_custom_status(self, custom_status): + self._custom_status = custom_status + + +def drive_raise(gen, exc: Exception): + # Prime + task = gen.send(None) + assert isinstance(task, FakeTask) + # Simulate runtime failure of yielded task + try: + gen.throw(exc) + except StopIteration as stop: + return stop.value + + +async def wf_catches_activity_error(ctx: AsyncWorkflowContext): + try: + await ctx.call_activity(lambda: (_ for _ in ()).throw(RuntimeError('boom'))) + except RuntimeError as e: + return f'caught:{e}' + return 'not-reached' + + +def test_activity_error_propagates_into_coroutine_and_can_be_caught(): + fake = FakeOrchestrationContext() + runner = CoroutineOrchestratorRunner(wf_catches_activity_error) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + result = drive_raise(gen, RuntimeError('boom')) + assert result == 'caught:boom' + + +async def wf_returns_sync(ctx: AsyncWorkflowContext): + return 42 + + +def test_sync_return_is_handled_without_runtime_error(): + fake = FakeOrchestrationContext() + runner = CoroutineOrchestratorRunner(wf_returns_sync) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + # Prime and complete + try: + gen.send(None) + except StopIteration as stop: + assert stop.value == 42 + + +class _FakeRegistry: + def __init__(self): + self.orchestrators = {} + self.activities = {} + + def add_named_orchestrator(self, name, fn): + self.orchestrators[name] = fn + + def add_named_activity(self, name, fn): + self.activities[name] = fn + + +class _FakeWorker: + def __init__(self, *args, **kwargs): + self._registry = _FakeRegistry() + + def start(self): + pass + + def stop(self): + pass + + +def test_generator_and_async_registration_coexist(monkeypatch): + # Monkeypatch TaskHubGrpcWorker to avoid real gRPC + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + rt = WorkflowRuntime() + + @rt.workflow(name='gen_wf') + def gen(ctx): + yield ctx.create_timer(0) + return 'ok' + + async def async_wf(ctx: AsyncWorkflowContext): + await ctx.sleep(0) + return 'ok' + + rt.register_async_workflow(async_wf, name='async_wf') + + # Verify registry got both entries + reg = 
rt._WorkflowRuntime__worker._registry + assert 'gen_wf' in reg.orchestrators + assert 'async_wf' in reg.orchestrators + + # Drive generator orchestrator wrapper + gen_fn = reg.orchestrators['gen_wf'] + g = gen_fn(FakeOrchestrationContext()) + t = next(g) + assert isinstance(t, FakeTask) + try: + g.send(None) + except StopIteration as stop: + assert stop.value == 'ok' + + # Also verify CancelledError propagates and can be caught + import asyncio + + async def wf_cancel(ctx: AsyncWorkflowContext): + try: + await ctx.call_activity(lambda: None) + except asyncio.CancelledError: + return 'cancelled' + return 'not-reached' + + runner = CoroutineOrchestratorRunner(wf_cancel) + gen_2 = runner.to_generator(AsyncWorkflowContext(FakeOrchestrationContext()), None) + # prime + next(gen_2) + try: + gen_2.throw(asyncio.CancelledError()) + except StopIteration as stop: + assert stop.value == 'cancelled' diff --git a/ext/dapr-ext-workflow/tests/test_async_registration_via_workflow.py b/ext/dapr-ext-workflow/tests/test_async_registration_via_workflow.py new file mode 100644 index 00000000..f1df6720 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_async_registration_via_workflow.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from dapr.ext.workflow.aio import AsyncWorkflowContext +from dapr.ext.workflow.workflow_runtime import WorkflowRuntime + + +class _FakeRegistry: + def __init__(self): + self.orchestrators = {} + + def add_named_orchestrator(self, name, fn): + self.orchestrators[name] = fn + + +class _FakeWorker: + def __init__(self, *args, **kwargs): + self._registry = _FakeRegistry() + + def start(self): + pass + + def stop(self): + pass + + +def test_workflow_decorator_detects_async_and_registers(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + rt = WorkflowRuntime() + + @rt.workflow(name='async_wf') + async def async_wf(ctx: AsyncWorkflowContext, x: int) -> int: + # no awaits to keep simple + return x + 1 + + # ensure it was placed into registry + reg = rt._WorkflowRuntime__worker._registry + assert 'async_wf' in reg.orchestrators diff --git a/ext/dapr-ext-workflow/tests/test_async_replay.py b/ext/dapr-ext-workflow/tests/test_async_replay.py new file mode 100644 index 00000000..ed6a52f4 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_async_replay.py @@ -0,0 +1,77 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +from datetime import datetime, timedelta + +from dapr.ext.workflow.aio import AsyncWorkflowContext, CoroutineOrchestratorRunner + + +class FakeTask: + def __init__(self, name: str): + self.name = name + + +class FakeCtx: + def __init__(self, instance_id: str = 'iid-replay', now: datetime | None = None): + self.current_utc_datetime = now or datetime(2024, 1, 1) + self.instance_id = instance_id + + def call_activity(self, activity, *, input=None, retry_policy=None, metadata=None, app_id=None): + return FakeTask(f'activity:{getattr(activity, "__name__", str(activity))}:{input}') + + def create_timer(self, fire_at): + return FakeTask(f'timer:{fire_at}') + + def wait_for_external_event(self, name: str): + return FakeTask(f'event:{name}') + + +def drive_with_history(gen, results): + """Drive the generator with a pre-baked sequence of results, simulating replay history.""" + try: + next(gen) + idx = 0 + while True: + gen.send(results[idx]) + idx += 1 + except StopIteration as stop: + return stop.value + + +async def wf_mixed(ctx: AsyncWorkflowContext): + # activity + r1 = await ctx.call_activity(lambda: None, input={'x': 1}) + # timer + await ctx.sleep(timedelta(seconds=5)) + # event + e = await ctx.wait_for_external_event('go') + # deterministic utils + t = ctx.now() + u = str(ctx.uuid4()) + return {'a': r1, 'e': e, 't': t.isoformat(), 'u': u} + + +def test_replay_same_history_same_outputs(): + fake = FakeCtx() + runner = CoroutineOrchestratorRunner(wf_mixed) + # Pre-bake results sequence corresponding to activity -> timer -> event + history = [ + {'task': "activity:lambda:{'x': 1}"}, + None, + {'event': 42}, + ] + out1 = drive_with_history(runner.to_generator(AsyncWorkflowContext(fake), None), history) + out2 = drive_with_history(runner.to_generator(AsyncWorkflowContext(fake), None), history) + assert out1 == out2 diff --git a/ext/dapr-ext-workflow/tests/test_async_sandbox.py b/ext/dapr-ext-workflow/tests/test_async_sandbox.py new file mode 100644 index 00000000..d19ef30e --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_async_sandbox.py @@ -0,0 +1,103 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import asyncio +import random +import time + +import pytest +from dapr.ext.workflow.aio import AsyncWorkflowContext, CoroutineOrchestratorRunner +from durabletask.aio.errors import SandboxViolationError +from durabletask.aio.sandbox import SandboxMode + + +class FakeTask: + def __init__(self, name: str): + self.name = name + + +class FakeCtx: + def __init__(self): + self.current_utc_datetime = __import__('datetime').datetime(2024, 1, 1) + self.instance_id = 'iid-sandbox' + + def call_activity(self, activity, *, input=None, retry_policy=None, metadata=None, app_id=None): + return FakeTask('activity') + + def create_timer(self, fire_at): + return FakeTask('timer') + + def wait_for_external_event(self, name: str): + return FakeTask(f'event:{name}') + + +async def wf_sleep(ctx: AsyncWorkflowContext): + # asyncio.sleep should be patched to workflow timer + await asyncio.sleep(0.1) + return 'ok' + + +def drive(gen, first_result=None): + try: + task = gen.send(None) + assert isinstance(task, FakeTask) + result = first_result + while True: + task = gen.send(result) + assert isinstance(task, FakeTask) + result = None + except StopIteration as stop: + return stop.value + + +def test_sandbox_best_effort_patches_sleep(): + fake = FakeCtx() + runner = CoroutineOrchestratorRunner(wf_sleep, sandbox_mode=SandboxMode.BEST_EFFORT) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + result = drive(gen) + assert result == 'ok' + + +def test_sandbox_random_uuid_time_are_deterministic(): + fake = FakeCtx() + runner = CoroutineOrchestratorRunner( + lambda ctx: _wf_random_uuid_time(ctx), sandbox_mode=SandboxMode.BEST_EFFORT + ) + gen1 = runner.to_generator(AsyncWorkflowContext(fake), None) + out1 = drive(gen1) + gen2 = runner.to_generator(AsyncWorkflowContext(fake), None) + out2 = drive(gen2) + assert out1 == out2 + + +async def _wf_random_uuid_time(ctx: AsyncWorkflowContext): + r1 = random.random() + u1 = __import__('uuid').uuid4() + t1 = time.time(), getattr(time, 'time_ns', lambda: int(time.time() * 1_000_000_000))() + # no awaits needed; return tuple + return (r1, str(u1), t1[0], t1[1]) + + +def test_strict_blocks_create_task(): + async def wf(ctx: AsyncWorkflowContext): + with pytest.raises(SandboxViolationError): + asyncio.create_task(asyncio.sleep(0)) + return 'ok' + + fake = FakeCtx() + runner = CoroutineOrchestratorRunner(wf, sandbox_mode=SandboxMode.STRICT) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + result = drive(gen) + assert result == 'ok' diff --git a/ext/dapr-ext-workflow/tests/test_async_sub_orchestrator.py b/ext/dapr-ext-workflow/tests/test_async_sub_orchestrator.py new file mode 100644 index 00000000..08f23d77 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_async_sub_orchestrator.py @@ -0,0 +1,96 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import pytest +from dapr.ext.workflow.aio import AsyncWorkflowContext, CoroutineOrchestratorRunner + + +class FakeTask: + def __init__(self, name: str): + self.name = name + + +class FakeCtx: + def __init__(self): + import datetime + + self.current_utc_datetime = datetime.datetime(2024, 1, 1) + self.instance_id = 'iid-sub' + + def call_activity(self, activity, *, input=None, retry_policy=None, metadata=None, app_id=None): + return FakeTask('activity') + + def call_child_workflow( + self, workflow, *, input=None, instance_id=None, retry_policy=None, metadata=None, app_id=None + ): + return FakeTask(f'sub:{getattr(workflow, "__name__", str(workflow))}') + + def create_timer(self, fire_at): + return FakeTask('timer') + + def wait_for_external_event(self, name: str): + return FakeTask(f'event:{name}') + + +def drive_success(gen, results): + try: + next(gen) + idx = 0 + while True: + gen.send(results[idx]) + idx += 1 + except StopIteration as stop: + return stop.value + + +def drive_raise(gen, exc: Exception): + # Prime + next(gen) + # Throw failure into orchestrator + return pytest.raises(Exception, gen.throw, exc) + + +async def child(ctx: AsyncWorkflowContext, n: int) -> int: + return n * 2 + + +async def parent_success(ctx: AsyncWorkflowContext): + res = await ctx.call_child_workflow(child, input=3) + return res + 1 + + +def test_sub_orchestrator_success(): + fake = FakeCtx() + runner = CoroutineOrchestratorRunner(parent_success) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + # First yield is the sub-orchestrator task + result = drive_success(gen, results=[6]) + assert result == 7 + + +async def parent_failure(ctx: AsyncWorkflowContext): + # Do not catch; allow failure to propagate + await ctx.call_child_workflow(child, input=1) + return 'not-reached' + + +def test_sub_orchestrator_failure_raises_into_orchestrator(): + fake = FakeCtx() + runner = CoroutineOrchestratorRunner(parent_failure) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + # Prime and then throw into the coroutine to simulate child failure + next(gen) + with pytest.raises(RuntimeError, match='child failed'): + gen.throw(RuntimeError('child failed')) diff --git a/ext/dapr-ext-workflow/tests/test_async_when_any_losers_policy.py b/ext/dapr-ext-workflow/tests/test_async_when_any_losers_policy.py new file mode 100644 index 00000000..99fa242d --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_async_when_any_losers_policy.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +from dapr.ext.workflow.aio import AsyncWorkflowContext, CoroutineOrchestratorRunner +from durabletask import task as durable_task_module + + +class FakeTask: + def __init__(self, name: str): + self.name = name + + +class FakeCtx: + def __init__(self): + import datetime + + self.current_utc_datetime = datetime.datetime(2024, 1, 1) + self.instance_id = 'iid-any' + + def call_activity(self, activity, *, input=None, retry_policy=None, metadata=None, app_id=None): + return FakeTask('activity') + + def create_timer(self, fire_at): + return FakeTask('timer') + + +async def wf_when_any(ctx: AsyncWorkflowContext): + # Two awaitables: an activity and a timer + a = ctx.call_activity(lambda: None) + b = ctx.sleep(10) + first = await ctx.when_any([a, b]) + return first + + +def test_when_any_yields_once_and_returns_first_result(monkeypatch): + # Patch durabletask.when_any to avoid requiring real durabletask.Task objects + monkeypatch.setattr(durable_task_module, 'when_any', lambda tasks: FakeTask('when_any')) + fake = FakeCtx() + runner = CoroutineOrchestratorRunner(wf_when_any) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + + # Prime; expect a single composite yield + yielded = gen.send(None) + assert isinstance(yielded, FakeTask) + # Send the 'first' completion; generator should complete without yielding again + try: + gen.send({'task': 'activity'}) + raise AssertionError('generator should have completed') + except StopIteration as stop: + assert stop.value == {'task': 'activity'} diff --git a/ext/dapr-ext-workflow/tests/test_async_workflow_basic.py b/ext/dapr-ext-workflow/tests/test_async_workflow_basic.py new file mode 100644 index 00000000..0f176e79 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_async_workflow_basic.py @@ -0,0 +1,95 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +from dapr.ext.workflow.aio import AsyncWorkflowContext, CoroutineOrchestratorRunner + + +class FakeTask: + def __init__(self, name: str): + self.name = name + + +class FakeCtx: + def __init__(self): + self.current_utc_datetime = __import__('datetime').datetime(2024, 1, 1) + self.instance_id = 'test-instance' + self._events: dict[str, list] = {} + + def call_activity(self, activity, *, input=None, retry_policy=None, metadata=None, app_id=None): + return FakeTask(f'activity:{getattr(activity, "__name__", str(activity))}') + + def call_child_workflow( + self, workflow, *, input=None, instance_id=None, retry_policy=None, metadata=None, app_id=None + ): + return FakeTask(f'sub:{getattr(workflow, "__name__", str(workflow))}') + + def create_timer(self, fire_at): + return FakeTask('timer') + + def wait_for_external_event(self, name: str): + return FakeTask(f'event:{name}') + + +def drive(gen, first_result=None): + """Drive a generator produced by the async driver, emulating the runtime.""" + try: + task = gen.send(None) + assert isinstance(task, FakeTask) + result = first_result + while True: + task = gen.send(result) + assert isinstance(task, FakeTask) + # Provide a generic result for every yield + result = {'task': task.name} + except StopIteration as stop: + return stop.value + + +async def sample_activity(ctx: AsyncWorkflowContext): + return await ctx.call_activity(lambda: None) + + +def test_activity_awaitable_roundtrip(): + fake = FakeCtx() + runner = CoroutineOrchestratorRunner(sample_activity) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + result = drive(gen, first_result={'task': 'activity:lambda'}) + assert result == {'task': 'activity:lambda'} + + +async def sample_timer(ctx: AsyncWorkflowContext): + await ctx.create_timer(1.0) + return 'done' + + +def test_timer_awaitable_roundtrip(): + fake = FakeCtx() + runner = CoroutineOrchestratorRunner(sample_timer) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + result = drive(gen, first_result=None) + assert result == 'done' + + +async def sample_event(ctx: AsyncWorkflowContext): + data = await ctx.wait_for_external_event('go') + return ('event', data) + + +def test_event_awaitable_roundtrip(): + fake = FakeCtx() + runner = CoroutineOrchestratorRunner(sample_event) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + result = drive(gen, first_result={'hello': 'world'}) + assert result == ('event', {'hello': 'world'}) diff --git a/ext/dapr-ext-workflow/tests/test_deterministic.py b/ext/dapr-ext-workflow/tests/test_deterministic.py new file mode 100644 index 00000000..fa76f22f --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_deterministic.py @@ -0,0 +1,74 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +from __future__ import annotations + +import datetime as _dt + +import pytest +from dapr.ext.workflow.aio import AsyncWorkflowContext +from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext + +""" +Tests for deterministic helpers shared across workflow contexts. +""" + + +class _FakeBaseCtx: + def __init__(self, instance_id: str, dt: _dt.datetime): + self.instance_id = instance_id + self.current_utc_datetime = dt + + +def _fixed_dt(): + return _dt.datetime(2024, 1, 1) + + +def test_random_string_deterministic_across_instances_async(): + base = _FakeBaseCtx('iid-1', _fixed_dt()) + a_ctx = AsyncWorkflowContext(base) + b_ctx = AsyncWorkflowContext(base) + a = a_ctx.random_string(16) + b = b_ctx.random_string(16) + assert a == b + + +def test_random_string_deterministic_across_context_types(): + base = _FakeBaseCtx('iid-2', _fixed_dt()) + a_ctx = AsyncWorkflowContext(base) + s1 = a_ctx.random_string(12) + + # Minimal fake orchestration context for DaprWorkflowContext + d_ctx = DaprWorkflowContext(base) + s2 = d_ctx.random_string(12) + assert s1 == s2 + + +def test_random_string_respects_alphabet(): + base = _FakeBaseCtx('iid-3', _fixed_dt()) + ctx = AsyncWorkflowContext(base) + s = ctx.random_string(20, alphabet='abc') + assert set(s).issubset(set('abc')) + + +def test_random_string_length_and_edge_cases(): + base = _FakeBaseCtx('iid-4', _fixed_dt()) + ctx = AsyncWorkflowContext(base) + + assert ctx.random_string(0) == '' + + with pytest.raises(ValueError): + ctx.random_string(-1) + + with pytest.raises(ValueError): + ctx.random_string(5, alphabet='') diff --git a/ext/dapr-ext-workflow/tests/test_generic_serialization.py b/ext/dapr-ext-workflow/tests/test_generic_serialization.py new file mode 100644 index 00000000..0aeb0c84 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_generic_serialization.py @@ -0,0 +1,64 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +from dataclasses import dataclass +from typing import Any + +from dapr.ext.workflow import ( + ActivityIOAdapter, + CanonicalSerializable, + ensure_canonical_json, + serialize_activity_input, + serialize_activity_output, + use_activity_adapter, +) + + +@dataclass +class _Point(CanonicalSerializable): + x: int + y: int + + def to_canonical_json(self, *, strict: bool = True) -> Any: + return {'x': self.x, 'y': self.y} + + +def test_ensure_canonical_json_on_custom_object(): + p = _Point(1, 2) + out = ensure_canonical_json(p, strict=True) + assert out == {'x': 1, 'y': 2} + + +class _IO(ActivityIOAdapter): + def serialize_input(self, input: Any, *, strict: bool = True) -> Any: + if isinstance(input, _Point): + return {'pt': [input.x, input.y]} + return ensure_canonical_json(input, strict=strict) + + def serialize_output(self, output: Any, *, strict: bool = True) -> Any: + return {'ok': ensure_canonical_json(output, strict=strict)} + + +def test_activity_adapter_decorator_customizes_io(): + _use = use_activity_adapter(_IO()) + + @_use + def act(obj): + return obj + + pt = _Point(3, 4) + inp = serialize_activity_input(act, pt, strict=True) + assert inp == {'pt': [3, 4]} + + out = serialize_activity_output(act, {'k': 'v'}, strict=True) + assert out == {'ok': {'k': 'v'}} diff --git a/ext/dapr-ext-workflow/tests/test_inbound_interceptors.py b/ext/dapr-ext-workflow/tests/test_inbound_interceptors.py new file mode 100644 index 00000000..bc3b28e0 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_inbound_interceptors.py @@ -0,0 +1,559 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +import asyncio +from typing import Any + +import pytest +from dapr.ext.workflow import ( + ExecuteActivityRequest, + ExecuteWorkflowRequest, + RuntimeInterceptor, + WorkflowRuntime, +) + +from ._fakes import make_act_ctx as _make_act_ctx +from ._fakes import make_orch_ctx as _make_orch_ctx + +""" +Comprehensive inbound interceptor tests for Dapr WorkflowRuntime. + +Tests the current interceptor system for runtime-side workflow and activity execution. 
+""" + + +class _FakeRegistry: + def __init__(self): + self.orchestrators: dict[str, Any] = {} + self.activities: dict[str, Any] = {} + + def add_named_orchestrator(self, name, fn): + self.orchestrators[name] = fn + + def add_named_activity(self, name, fn): + self.activities[name] = fn + + +class _FakeWorker: + def __init__(self, *args, **kwargs): + self._registry = _FakeRegistry() + + def start(self): + pass + + def stop(self): + pass + + +class _TracingInterceptor(RuntimeInterceptor): + """Interceptor that injects and restores trace context.""" + + def __init__(self, events: list[str]): + self.events = events + + def execute_workflow(self, request: ExecuteWorkflowRequest, next): + # Extract tracing from input + tracing_data = None + if isinstance(request.input, dict) and 'tracing' in request.input: + tracing_data = request.input['tracing'] + self.events.append(f'wf_trace_restored:{tracing_data}') + + # Call next in chain + result = next(request) + + if tracing_data: + self.events.append(f'wf_trace_cleanup:{tracing_data}') + + return result + + def execute_activity(self, request: ExecuteActivityRequest, next): + # Extract tracing from input + tracing_data = None + if isinstance(request.input, dict) and 'tracing' in request.input: + tracing_data = request.input['tracing'] + self.events.append(f'act_trace_restored:{tracing_data}') + + # Call next in chain + result = next(request) + + if tracing_data: + self.events.append(f'act_trace_cleanup:{tracing_data}') + + return result + + +class _LoggingInterceptor(RuntimeInterceptor): + """Interceptor that logs workflow and activity execution.""" + + def __init__(self, events: list[str], label: str): + self.events = events + self.label = label + + def execute_workflow(self, request: ExecuteWorkflowRequest, next): + self.events.append(f'{self.label}:wf_start:{request.input!r}') + try: + result = next(request) + self.events.append(f'{self.label}:wf_complete:{result!r}') + return result + except Exception as e: + self.events.append(f'{self.label}:wf_error:{type(e).__name__}') + raise + + def execute_activity(self, request: ExecuteActivityRequest, next): + self.events.append(f'{self.label}:act_start:{request.input!r}') + try: + result = next(request) + self.events.append(f'{self.label}:act_complete:{result!r}') + return result + except Exception as e: + self.events.append(f'{self.label}:act_error:{type(e).__name__}') + raise + + +class _ValidationInterceptor(RuntimeInterceptor): + """Interceptor that validates inputs and outputs.""" + + def __init__(self, events: list[str]): + self.events = events + + def execute_workflow(self, request: ExecuteWorkflowRequest, next): + # Validate input + if isinstance(request.input, dict) and request.input.get('invalid'): + self.events.append('wf_validation_failed') + raise ValueError('Invalid workflow input') + + self.events.append('wf_validation_passed') + result = next(request) + + # Validate output + if isinstance(result, dict) and result.get('invalid_output'): + self.events.append('wf_output_validation_failed') + raise ValueError('Invalid workflow output') + + self.events.append('wf_output_validation_passed') + return result + + def execute_activity(self, request: ExecuteActivityRequest, next): + # Validate input + if isinstance(request.input, dict) and request.input.get('invalid'): + self.events.append('act_validation_failed') + raise ValueError('Invalid activity input') + + self.events.append('act_validation_passed') + result = next(request) + + # Validate output + if isinstance(result, str) and 'invalid' in 
result: + self.events.append('act_output_validation_failed') + raise ValueError('Invalid activity output') + + self.events.append('act_output_validation_passed') + return result + + +def test_single_interceptor_workflow_execution(monkeypatch): + """Test single interceptor around workflow execution.""" + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + logging_interceptor = _LoggingInterceptor(events, 'log') + rt = WorkflowRuntime(runtime_interceptors=[logging_interceptor]) + + @rt.workflow(name='simple') + def simple(ctx, x: int): + return x * 2 + + reg = rt._WorkflowRuntime__worker._registry + orch = reg.orchestrators['simple'] + result = orch(_make_orch_ctx(), 5) + + # For non-generator workflows, the result is returned directly + assert result == 10 + assert events == [ + 'log:wf_start:5', + 'log:wf_complete:10', + ] + + +def test_single_interceptor_activity_execution(monkeypatch): + """Test single interceptor around activity execution.""" + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + logging_interceptor = _LoggingInterceptor(events, 'log') + rt = WorkflowRuntime(runtime_interceptors=[logging_interceptor]) + + @rt.activity(name='double') + def double(ctx, x: int) -> int: + return x * 2 + + reg = rt._WorkflowRuntime__worker._registry + act = reg.activities['double'] + result = act(_make_act_ctx(), 7) + + assert result == 14 + assert events == [ + 'log:act_start:7', + 'log:act_complete:14', + ] + + +def test_multiple_interceptors_execution_order(monkeypatch): + """Test multiple interceptors execute in correct order.""" + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + outer_interceptor = _LoggingInterceptor(events, 'outer') + inner_interceptor = _LoggingInterceptor(events, 'inner') + + # First interceptor in list is outermost + rt = WorkflowRuntime(runtime_interceptors=[outer_interceptor, inner_interceptor]) + + @rt.workflow(name='ordered') + def ordered(ctx, x: int): + return x + 1 + + reg = rt._WorkflowRuntime__worker._registry + orch = reg.orchestrators['ordered'] + result = orch(_make_orch_ctx(), 3) + + assert result == 4 + # Outer interceptor enters first, exits last (stack semantics) + assert events == [ + 'outer:wf_start:3', + 'inner:wf_start:3', + 'inner:wf_complete:4', + 'outer:wf_complete:4', + ] + + +def test_tracing_interceptor_context_restoration(monkeypatch): + """Test tracing interceptor properly handles trace context.""" + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + tracing_interceptor = _TracingInterceptor(events) + rt = WorkflowRuntime(runtime_interceptors=[tracing_interceptor]) + + @rt.workflow(name='traced') + def traced(ctx, input_data): + # Workflow can access the trace context that was restored + return {'result': input_data.get('value', 0) * 2} + + reg = rt._WorkflowRuntime__worker._registry + orch = reg.orchestrators['traced'] + + # Input with tracing data + input_with_trace = {'value': 5, 'tracing': {'trace_id': 'abc123', 'span_id': 'def456'}} + + result = orch(_make_orch_ctx(), input_with_trace) + + assert result == {'result': 10} + assert events == [ + "wf_trace_restored:{'trace_id': 'abc123', 'span_id': 'def456'}", + "wf_trace_cleanup:{'trace_id': 'abc123', 'span_id': 'def456'}", + ] + + +def 
test_validation_interceptor_input_validation(monkeypatch): + """Test validation interceptor catches invalid inputs.""" + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + validation_interceptor = _ValidationInterceptor(events) + rt = WorkflowRuntime(runtime_interceptors=[validation_interceptor]) + + @rt.workflow(name='validated') + def validated(ctx, input_data): + return {'result': 'ok'} + + reg = rt._WorkflowRuntime__worker._registry + orch = reg.orchestrators['validated'] + + # Test valid input + result = orch(_make_orch_ctx(), {'value': 5}) + + assert result == {'result': 'ok'} + assert 'wf_validation_passed' in events + assert 'wf_output_validation_passed' in events + + # Test invalid input + events.clear() + + with pytest.raises(ValueError, match='Invalid workflow input'): + orch(_make_orch_ctx(), {'invalid': True}) + + assert 'wf_validation_failed' in events + + +def test_interceptor_error_handling_workflow(monkeypatch): + """Test interceptor properly handles workflow errors.""" + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + logging_interceptor = _LoggingInterceptor(events, 'log') + rt = WorkflowRuntime(runtime_interceptors=[logging_interceptor]) + + @rt.workflow(name='error_wf') + def error_wf(ctx, x: int): + raise ValueError('workflow error') + + reg = rt._WorkflowRuntime__worker._registry + orch = reg.orchestrators['error_wf'] + + with pytest.raises(ValueError, match='workflow error'): + orch(_make_orch_ctx(), 1) + + assert events == [ + 'log:wf_start:1', + 'log:wf_error:ValueError', + ] + + +def test_interceptor_error_handling_activity(monkeypatch): + """Test interceptor properly handles activity errors.""" + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + logging_interceptor = _LoggingInterceptor(events, 'log') + rt = WorkflowRuntime(runtime_interceptors=[logging_interceptor]) + + @rt.activity(name='error_act') + def error_act(ctx, x: int) -> int: + raise RuntimeError('activity error') + + reg = rt._WorkflowRuntime__worker._registry + act = reg.activities['error_act'] + + with pytest.raises(RuntimeError, match='activity error'): + act(_make_act_ctx(), 5) + + assert events == [ + 'log:act_start:5', + 'log:act_error:RuntimeError', + ] + + +def test_async_workflow_with_interceptors(monkeypatch): + """Test interceptors work with async workflows.""" + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + logging_interceptor = _LoggingInterceptor(events, 'log') + rt = WorkflowRuntime(runtime_interceptors=[logging_interceptor]) + + @rt.workflow(name='async_wf') + async def async_wf(ctx, x: int): + # Simple async workflow + return x * 3 + + reg = rt._WorkflowRuntime__worker._registry + orch = reg.orchestrators['async_wf'] + gen_result = orch(_make_orch_ctx(), 4) + + # Async workflows return a generator that needs to be driven + with pytest.raises(StopIteration) as stop: + next(gen_result) + result = stop.value.value + + assert result == 12 + # The interceptor sees the generator being returned, not the final result + assert events[0] == 'log:wf_start:4' + assert 'log:wf_complete:' in events[1] # The generator object is logged + + +def test_async_activity_with_interceptors(monkeypatch): + """Test interceptors work with async 
activities.""" + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + logging_interceptor = _LoggingInterceptor(events, 'log') + rt = WorkflowRuntime(runtime_interceptors=[logging_interceptor]) + + @rt.activity(name='async_act') + async def async_act(ctx, x: int) -> int: + await asyncio.sleep(0) # Simulate async work + return x * 4 + + reg = rt._WorkflowRuntime__worker._registry + act = reg.activities['async_act'] + result = act(_make_act_ctx(), 3) + + assert result == 12 + assert events == [ + 'log:act_start:3', + 'log:act_complete:12', + ] + + +def test_generator_workflow_with_interceptors(monkeypatch): + """Test interceptors work with generator workflows.""" + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + logging_interceptor = _LoggingInterceptor(events, 'log') + rt = WorkflowRuntime(runtime_interceptors=[logging_interceptor]) + + @rt.workflow(name='gen_wf') + def gen_wf(ctx, x: int): + v1 = yield 'step1' + v2 = yield 'step2' + return (x, v1, v2) + + reg = rt._WorkflowRuntime__worker._registry + orch = reg.orchestrators['gen_wf'] + gen_orch = orch(_make_orch_ctx(), 1) + + # Drive the generator + assert next(gen_orch) == 'step1' + assert gen_orch.send('result1') == 'step2' + with pytest.raises(StopIteration) as stop: + gen_orch.send('result2') + result = stop.value.value + + assert result == (1, 'result1', 'result2') + # For generator workflows, interceptor sees the generator being returned + assert events[0] == 'log:wf_start:1' + assert 'log:wf_complete:' in events[1] # The generator object is logged + + +def test_interceptor_chain_with_early_return(monkeypatch): + """Test interceptor can modify or short-circuit execution.""" + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + + class _ShortCircuitInterceptor(RuntimeInterceptor): + def execute_workflow(self, request: ExecuteWorkflowRequest, next): + events.append('short_circuit_check') + if isinstance(request.input, dict) and request.input.get('short_circuit'): + events.append('short_circuited') + return 'short_circuit_result' + return next(request) + + def execute_activity(self, request: ExecuteActivityRequest, next): + return next(request) + + logging_interceptor = _LoggingInterceptor(events, 'log') + short_circuit_interceptor = _ShortCircuitInterceptor() + + rt = WorkflowRuntime(runtime_interceptors=[short_circuit_interceptor, logging_interceptor]) + + @rt.workflow(name='maybe_short') + def maybe_short(ctx, input_data): + return 'normal_result' + + reg = rt._WorkflowRuntime__worker._registry + orch = reg.orchestrators['maybe_short'] + + # Test normal execution + result = orch(_make_orch_ctx(), {'value': 5}) + + assert result == 'normal_result' + assert 'short_circuit_check' in events + assert 'log:wf_start' in str(events) + assert 'log:wf_complete' in str(events) + + # Test short-circuit execution + events.clear() + result = orch(_make_orch_ctx(), {'short_circuit': True}) + + assert result == 'short_circuit_result' + assert 'short_circuit_check' in events + assert 'short_circuited' in events + # Logging interceptor should not be called when short-circuited + assert 'log:wf_start' not in str(events) + + +def test_interceptor_input_transformation(monkeypatch): + """Test interceptor can transform inputs before execution.""" + import durabletask.worker as worker_mod + + 
monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + + class _TransformInterceptor(RuntimeInterceptor): + def execute_workflow(self, request: ExecuteWorkflowRequest, next): + # Transform input by adding metadata + if isinstance(request.input, dict): + transformed_input = {**request.input, 'interceptor_metadata': 'added'} + new_input = ExecuteWorkflowRequest(ctx=request.ctx, input=transformed_input) + events.append(f'transformed_input:{transformed_input}') + return next(new_input) + return next(request) + + def execute_activity(self, request: ExecuteActivityRequest, next): + return next(request) + + transform_interceptor = _TransformInterceptor() + rt = WorkflowRuntime(runtime_interceptors=[transform_interceptor]) + + @rt.workflow(name='transform_test') + def transform_test(ctx, input_data): + # Workflow should see the transformed input + return input_data + + reg = rt._WorkflowRuntime__worker._registry + orch = reg.orchestrators['transform_test'] + result = orch(_make_orch_ctx(), {'original': 'value'}) + + # Result should include the interceptor metadata + assert result == {'original': 'value', 'interceptor_metadata': 'added'} + assert 'transformed_input:' in str(events) + + +def test_runtime_interceptor_can_shape_activity_result(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + class _ShapeResult(RuntimeInterceptor): + def execute_activity(self, request, next): # type: ignore[override] + res = next(request) + return {'wrapped': res} + + rt = WorkflowRuntime(runtime_interceptors=[_ShapeResult()]) + + @rt.activity(name='echo') + def echo(_ctx, x): + return x + + reg = rt._WorkflowRuntime__worker._registry + act = reg.activities['echo'] + out = act(_make_act_ctx(), 7) + assert out == {'wrapped': 7} diff --git a/ext/dapr-ext-workflow/tests/test_interceptors.py b/ext/dapr-ext-workflow/tests/test_interceptors.py new file mode 100644 index 00000000..9ba37287 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_interceptors.py @@ -0,0 +1,176 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +from typing import Any + +import pytest +from dapr.ext.workflow import RuntimeInterceptor, WorkflowRuntime + +from ._fakes import make_act_ctx as _make_act_ctx +from ._fakes import make_orch_ctx as _make_orch_ctx + +""" +Comprehensive inbound interceptor tests for Dapr WorkflowRuntime. + +Tests the current interceptor system for runtime-side workflow and activity execution. +""" + + +""" +Runtime interceptor chain tests for `WorkflowRuntime`. + +This suite intentionally uses a fake worker/registry to validate interceptor composition +without requiring a sidecar. It focuses on the "why" behind runtime interceptors: + +- Ensure `execute_workflow` and `execute_activity` hooks compose in order and are + invoked exactly once around workflow entry/activity execution. 
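+- Verify that exceptions raised inside an activity interceptor propagate to the caller rather
+  than being swallowed by the chain.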
+- Cover both generator-based and async workflows, asserting the chain returns a + generator to the runtime (rather than iterating it), preserving send()/throw() + semantics during orchestration replay. +- Keep signal-to-noise high for failures in chain logic independent of gRPC/sidecar. + +These tests complement outbound/client interceptor tests and e2e tests by providing +fast, deterministic coverage of the chaining behavior and generator handling rules. +""" + + +class _FakeRegistry: + def __init__(self): + self.orchestrators: dict[str, Any] = {} + self.activities: dict[str, Any] = {} + + def add_named_orchestrator(self, name, fn): + self.orchestrators[name] = fn + + def add_named_activity(self, name, fn): + self.activities[name] = fn + + +class _FakeWorker: + def __init__(self, *args, **kwargs): + self._registry = _FakeRegistry() + + def start(self): + pass + + def stop(self): + pass + + +class _RecorderInterceptor(RuntimeInterceptor): + def __init__(self, events: list[str], label: str): + self.events = events + self.label = label + + def execute_workflow(self, request, next): # type: ignore[override] + self.events.append(f'{self.label}:wf_enter:{request.input!r}') + ret = next(request) + self.events.append(f'{self.label}:wf_ret_type:{ret.__class__.__name__}') + return ret + + def execute_activity(self, request, next): # type: ignore[override] + self.events.append(f'{self.label}:act_enter:{request.input!r}') + res = next(request) + self.events.append(f'{self.label}:act_exit:{res!r}') + return res + + +def test_generator_workflow_hooks_sequence(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + ic = _RecorderInterceptor(events, 'mw') + rt = WorkflowRuntime(runtime_interceptors=[ic]) + + @rt.workflow(name='gen') + def gen(ctx, x: int): + v = yield 'A' + v2 = yield 'B' + return (x, v, v2) + + # Drive the registered orchestrator + reg = rt._WorkflowRuntime__worker._registry + orch = reg.orchestrators['gen'] + gen_driver = orch(_make_orch_ctx(), 10) + # Prime and run + assert next(gen_driver) == 'A' + assert gen_driver.send('ra') == 'B' + with pytest.raises(StopIteration) as stop: + gen_driver.send('rb') + result = stop.value.value + + assert result == (10, 'ra', 'rb') + # Interceptors run once around the workflow entry; they return a generator to the runtime + assert events[0] == 'mw:wf_enter:10' + assert events[1].startswith('mw:wf_ret_type:') + + +def test_async_workflow_hooks_called(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + ic = _RecorderInterceptor(events, 'mw') + rt = WorkflowRuntime(runtime_interceptors=[ic]) + + @rt.workflow(name='awf') + async def awf(ctx, x: int): + # No awaits to keep the driver simple + return x + 1 + + reg = rt._WorkflowRuntime__worker._registry + orch = reg.orchestrators['awf'] + gen_orch = orch(_make_orch_ctx(), 41) + with pytest.raises(StopIteration) as stop: + next(gen_orch) + result = stop.value.value + + assert result == 42 + # For async workflow, interceptor sees entry and a generator return type + assert events[0] == 'mw:wf_enter:41' + assert events[1].startswith('mw:wf_ret_type:') + + +def test_activity_hooks_and_policy(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + + class _ExplodingActivity(RuntimeInterceptor): + def 
execute_activity(self, request, next): # type: ignore[override] + raise RuntimeError('boom') + + def execute_workflow(self, request, next): # type: ignore[override] + return next(request) + + # Continue-on-error policy + rt = WorkflowRuntime( + runtime_interceptors=[_RecorderInterceptor(events, 'mw'), _ExplodingActivity()] + ) + + @rt.activity(name='double') + def double(ctx, x: int) -> int: + return x * 2 + + reg = rt._WorkflowRuntime__worker._registry + act = reg.activities['double'] + # Error in interceptor bubbles up + with pytest.raises(RuntimeError): + act(_make_act_ctx(), 5) diff --git a/ext/dapr-ext-workflow/tests/test_metadata_context.py b/ext/dapr-ext-workflow/tests/test_metadata_context.py new file mode 100644 index 00000000..8a1f86fd --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_metadata_context.py @@ -0,0 +1,371 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any, Optional + +import pytest +from dapr.ext.workflow import ( + ClientInterceptor, + DaprWorkflowClient, + ExecuteActivityRequest, + ExecuteWorkflowRequest, + RuntimeInterceptor, + ScheduleWorkflowRequest, + WorkflowOutboundInterceptor, + WorkflowRuntime, +) + + +class _FakeRegistry: + def __init__(self): + self.orchestrators: dict[str, Any] = {} + self.activities: dict[str, Any] = {} + + def add_named_orchestrator(self, name, fn): + self.orchestrators[name] = fn + + def add_named_activity(self, name, fn): + self.activities[name] = fn + + +class _FakeWorker: + def __init__(self, *args, **kwargs): + self._registry = _FakeRegistry() + + def start(self): + pass + + def stop(self): + pass + + +class _FakeOrchCtx: + def __init__(self): + self.instance_id = 'id' + self.current_utc_datetime = datetime(2024, 1, 1) + self._custom_status = None + self.is_replaying = False + self.workflow_name = 'wf' + self.parent_instance_id = None + self.history_event_sequence = 1 + self.trace_parent = None + self.trace_state = None + self.orchestration_span_id = None + + def call_activity(self, activity, *, input=None, retry_policy=None, app_id=None): + class _T: + def __init__(self, v): + self._v = v + + return _T(input) + + def call_sub_orchestrator( + self, wf, *, input=None, instance_id=None, retry_policy=None, app_id=None + ): + class _T: + def __init__(self, v): + self._v = v + + return _T(input) + + def set_custom_status(self, custom_status): + self._custom_status = custom_status + + def create_timer(self, fire_at): + class _T: + def __init__(self, v): + self._v = v + + return _T(fire_at) + + def wait_for_external_event(self, name: str): + class _T: + def __init__(self, v): + self._v = v + + return _T(name) + + +def _drive(gen, returned): + try: + t = gen.send(None) + assert hasattr(t, '_v') + res = returned + while True: + t = gen.send(res) + assert hasattr(t, '_v') + except StopIteration as stop: + return stop.value + + +def test_client_schedule_metadata_envelope(monkeypatch): + import durabletask.client as client_mod + + 
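+    # `captured` records the arguments the client forwards to the underlying
+    # schedule_new_orchestration call, so the metadata envelope can be asserted on below.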
captured: dict[str, Any] = {} + + class _FakeClient: + def __init__(self, *args, **kwargs): + pass + + def schedule_new_orchestration( + self, + name, + *, + input=None, + instance_id=None, + start_at: Optional[datetime] = None, + reuse_id_policy=None, + ): # noqa: E501 + captured['name'] = name + captured['input'] = input + captured['instance_id'] = instance_id + captured['start_at'] = start_at + captured['reuse_id_policy'] = reuse_id_policy + return 'id-1' + + monkeypatch.setattr(client_mod, 'TaskHubGrpcClient', _FakeClient) + + class _InjectMetadata(ClientInterceptor): + def schedule_new_workflow(self, request: ScheduleWorkflowRequest, next): # type: ignore[override] + # Add metadata without touching args + md = {'otel.trace_id': 't-123'} + new_request = ScheduleWorkflowRequest( + workflow_name=request.workflow_name, + input=request.input, + instance_id=request.instance_id, + start_at=request.start_at, + reuse_id_policy=request.reuse_id_policy, + metadata=md, + ) + return next(new_request) + + client = DaprWorkflowClient(interceptors=[_InjectMetadata()]) + + def wf(ctx, x): + yield 'noop' + + wf.__name__ = 'meta_wf' + instance_id = client.schedule_new_workflow(wf, input={'a': 1}) + assert instance_id == 'id-1' + env = captured['input'] + assert isinstance(env, dict) + assert '__dapr_meta__' in env and '__dapr_payload__' in env + assert env['__dapr_payload__'] == {'a': 1} + assert env['__dapr_meta__']['metadata']['otel.trace_id'] == 't-123' + + +def test_runtime_inbound_unwrap_and_metadata_visible(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + seen: dict[str, Any] = {} + + class _Recorder(RuntimeInterceptor): + def execute_workflow(self, request: ExecuteWorkflowRequest, next): # type: ignore[override] + seen['metadata'] = request.metadata + return next(request) + + def execute_activity(self, request: ExecuteActivityRequest, next): # type: ignore[override] + seen['act_metadata'] = request.metadata + return next(request) + + rt = WorkflowRuntime(runtime_interceptors=[_Recorder()]) + + @rt.workflow(name='unwrap') + def unwrap(ctx, x): + # x should be the original payload, not the envelope + assert x == {'hello': 'world'} + return 'ok' + + orch = rt._WorkflowRuntime__worker._registry.orchestrators['unwrap'] + envelope = { + '__dapr_meta__': {'v': 1, 'metadata': {'c': 'd'}}, + '__dapr_payload__': {'hello': 'world'}, + } + result = orch(_FakeOrchCtx(), envelope) + assert result == 'ok' + assert seen['metadata'] == {'c': 'd'} + + +def test_outbound_activity_and_child_wrap_metadata(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + class _AddActMeta(WorkflowOutboundInterceptor): + def call_activity(self, request, next): # type: ignore[override] + # Wrap returned args with metadata by returning a new CallActivityRequest + return next( + type(request)( + activity_name=request.activity_name, + input=request.input, + retry_policy=request.retry_policy, + workflow_ctx=request.workflow_ctx, + metadata={'k': 'v'}, + ) + ) + + def call_child_workflow(self, request, next): # type: ignore[override] + return next( + type(request)( + workflow_name=request.workflow_name, + input=request.input, + instance_id=request.instance_id, + workflow_ctx=request.workflow_ctx, + metadata={'p': 'q'}, + ) + ) + + rt = WorkflowRuntime(workflow_outbound_interceptors=[_AddActMeta()]) + + @rt.workflow(name='parent') + def parent(ctx, x): + a = yield 
ctx.call_activity(lambda: None, input={'i': 1}) + b = yield ctx.call_child_workflow(lambda c, y: None, input={'j': 2}) + # Return both so we can assert envelopes surfaced through our fake driver + return a, b + + orch = rt._WorkflowRuntime__worker._registry.orchestrators['parent'] + gen = orch(_FakeOrchCtx(), 0) + # First yield: activity token received by driver; shape may be envelope or raw depending on adapter + t1 = gen.send(None) + assert hasattr(t1, '_v') + # Resume with any value; our fake driver ignores and loops + t2 = gen.send({'act': 'done'}) + assert hasattr(t2, '_v') + with pytest.raises(StopIteration) as stop: + gen.send({'child': 'done'}) + result = stop.value.value + # The result is whatever user returned; envelopes validated above + assert isinstance(result, tuple) and len(result) == 2 + + +def test_context_set_metadata_default_propagation(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + # No outbound interceptor needed; runtime will wrap using ctx.get_metadata() + rt = WorkflowRuntime() + + @rt.workflow(name='use_ctx_md') + def use_ctx_md(ctx, x): + # Set default metadata on context + ctx.set_metadata({'k': 'ctx'}) + env = yield ctx.call_activity(lambda: None, input={'p': 1}) + # Return the raw yielded value for assertion + return env + + orch = rt._WorkflowRuntime__worker._registry.orchestrators['use_ctx_md'] + gen = orch(_FakeOrchCtx(), 0) + yielded = gen.send(None) + assert hasattr(yielded, '_v') + env = yielded._v + assert isinstance(env, dict) + assert env.get('__dapr_meta__', {}).get('metadata', {}).get('k') == 'ctx' + + +def test_per_call_metadata_overrides_context(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + rt = WorkflowRuntime() + + @rt.workflow(name='override_ctx_md') + def override_ctx_md(ctx, x): + ctx.set_metadata({'k': 'ctx'}) + env = yield ctx.call_activity(lambda: None, input={'p': 1}, metadata={'k': 'per'}) + return env + + orch = rt._WorkflowRuntime__worker._registry.orchestrators['override_ctx_md'] + gen = orch(_FakeOrchCtx(), 0) + yielded = gen.send(None) + env = yielded._v + assert isinstance(env, dict) + assert env.get('__dapr_meta__', {}).get('metadata', {}).get('k') == 'per' + + +def test_execution_info_workflow_and_activity(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + rt = WorkflowRuntime() + + def act(ctx, x): + # activity inbound metadata and execution info available + md = ctx.get_metadata() + ei = ctx.execution_info + assert md == {'m': 'v'} + assert ei is not None and ei.inbound_metadata == {'m': 'v'} + # activity_name should reflect the registered name + assert ei.activity_name == 'act' + return x + + @rt.workflow(name='execinfo') + def execinfo(ctx, x): + # set default metadata + ctx.set_metadata({'m': 'v'}) + # workflow execution info available (minimal inbound only) + wi = ctx.execution_info + assert wi is not None and wi.inbound_metadata == {} + v = yield ctx.call_activity(act, input=42) + return v + + # register activity + rt.activity(name='act')(act) + orch = rt._WorkflowRuntime__worker._registry.orchestrators['execinfo'] + gen = orch(_FakeOrchCtx(), 7) + # drive one yield (call_activity) + gen.send(None) + # send back a value for activity result + with pytest.raises(StopIteration) as stop: + gen.send(42) + assert stop.value.value == 42 + + +def 
test_client_interceptor_can_shape_schedule_response(monkeypatch): + import durabletask.client as client_mod + + captured: dict[str, Any] = {} + + class _FakeClient: + def __init__(self, *args, **kwargs): + pass + + def schedule_new_orchestration( + self, name, *, input=None, instance_id=None, start_at=None, reuse_id_policy=None + ): + captured['name'] = name + return 'raw-id-123' + + monkeypatch.setattr(client_mod, 'TaskHubGrpcClient', _FakeClient) + + class _ShapeId(ClientInterceptor): + def schedule_new_workflow(self, request: ScheduleWorkflowRequest, next): # type: ignore[override] + rid = next(request) + return f'shaped:{rid}' + + client = DaprWorkflowClient(interceptors=[_ShapeId()]) + + def wf(ctx): + yield 'noop' + + wf.__name__ = 'shape_test' + iid = client.schedule_new_workflow(wf, input=None) + assert iid == 'shaped:raw-id-123' diff --git a/ext/dapr-ext-workflow/tests/test_outbound_interceptors.py b/ext/dapr-ext-workflow/tests/test_outbound_interceptors.py new file mode 100644 index 00000000..f7f1c1cb --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_outbound_interceptors.py @@ -0,0 +1,203 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +from dapr.ext.workflow import ( + BaseWorkflowOutboundInterceptor, + WorkflowOutboundInterceptor, + WorkflowRuntime, +) + + +class _FakeRegistry: + def __init__(self): + self.orchestrators = {} + + def add_named_orchestrator(self, name, fn): + self.orchestrators[name] = fn + + +class _FakeWorker: + def __init__(self, *args, **kwargs): + self._registry = _FakeRegistry() + + def start(self): + pass + + def stop(self): + pass + + +class _FakeOrchCtx: + def __init__(self): + self.instance_id = 'id' + self.current_utc_datetime = __import__('datetime').datetime(2024, 1, 1) + self.is_replaying = False + self._custom_status = None + self.workflow_name = 'wf' + self.parent_instance_id = None + self.history_event_sequence = 1 + self.trace_parent = None + self.trace_state = None + self.orchestration_span_id = None + self._continued_payload = None + self.workflow_attempt = None + + def call_activity(self, activity, *, input=None, retry_policy=None, app_id=None): + # return input back for assertion through driver + class _T: + def __init__(self, v): + self._v = v + + return _T(input) + + def call_sub_orchestrator( + self, wf, *, input=None, instance_id=None, retry_policy=None, app_id=None + ): + class _T: + def __init__(self, v): + self._v = v + + return _T(input) + + def set_custom_status(self, custom_status): + self._custom_status = custom_status + + def create_timer(self, fire_at): + class _T: + def __init__(self, v): + self._v = v + + return _T(fire_at) + + def wait_for_external_event(self, name: str): + class _T: + def __init__(self, v): + self._v = v + + return _T(name) + + def continue_as_new(self, new_request, *, save_events: bool = False): + # Record payload for assertions + self._continued_payload = new_request + + +def drive(gen, returned): + try: + t = gen.send(None) + assert hasattr(t, 
'_v') + res = returned + while True: + t = gen.send(res) + assert hasattr(t, '_v') + except StopIteration as stop: + return stop.value + + +class _InjectTrace(WorkflowOutboundInterceptor): + def call_activity(self, request, next): # type: ignore[override] + x = request.input + if x is None: + request = type(request)( + activity_name=request.activity_name, + input={'tracing': 'T'}, + retry_policy=request.retry_policy, + ) + elif isinstance(x, dict): + out = dict(x) + out.setdefault('tracing', 'T') + request = type(request)( + activity_name=request.activity_name, input=out, retry_policy=request.retry_policy + ) + return next(request) + + def call_child_workflow(self, request, next): # type: ignore[override] + return next( + type(request)( + workflow_name=request.workflow_name, + input={'child': request.input}, + instance_id=request.instance_id, + ) + ) + + +def test_outbound_activity_injection(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + rt = WorkflowRuntime(workflow_outbound_interceptors=[_InjectTrace()]) + + @rt.workflow(name='w') + def w(ctx, x): + # schedule an activity; runtime should pass transformed input to durable task + y = yield ctx.call_activity(lambda: None, input={'a': 1}) + return y['tracing'] + + orch = rt._WorkflowRuntime__worker._registry.orchestrators['w'] + gen = orch(_FakeOrchCtx(), 0) + out = drive(gen, returned={'tracing': 'T', 'a': 1}) + assert out == 'T' + + +def test_outbound_child_injection(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + rt = WorkflowRuntime(workflow_outbound_interceptors=[_InjectTrace()]) + + def child(ctx, x): + yield 'noop' + + @rt.workflow(name='parent') + def parent(ctx, x): + y = yield ctx.call_child_workflow(child, input={'b': 2}) + return y + + orch = rt._WorkflowRuntime__worker._registry.orchestrators['parent'] + gen = orch(_FakeOrchCtx(), 0) + out = drive(gen, returned={'child': {'b': 2}}) + assert out == {'child': {'b': 2}} + + +def test_outbound_continue_as_new_injection(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + class _InjectCAN(BaseWorkflowOutboundInterceptor): + def continue_as_new(self, request, next): # type: ignore[override] + md = dict(request.metadata or {}) + md.setdefault('x', '1') + request.metadata = md + return next(request) + + rt = WorkflowRuntime(workflow_outbound_interceptors=[_InjectCAN()]) + + @rt.workflow(name='w2') + def w2(ctx, x): + ctx.continue_as_new({'p': 1}, carryover_metadata=True) + return 'unreached' + + orch = rt._WorkflowRuntime__worker._registry.orchestrators['w2'] + fake = _FakeOrchCtx() + _ = orch(fake, 0) + # Verify envelope contains injected metadata + assert isinstance(fake._continued_payload, dict) + meta = fake._continued_payload.get('__dapr_meta__') + payload = fake._continued_payload.get('__dapr_payload__') + assert isinstance(meta, dict) and isinstance(payload, dict) + assert meta.get('metadata', {}).get('x') == '1' + assert payload == {'p': 1} diff --git a/ext/dapr-ext-workflow/tests/test_sandbox_gather.py b/ext/dapr-ext-workflow/tests/test_sandbox_gather.py new file mode 100644 index 00000000..7bb509dd --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_sandbox_gather.py @@ -0,0 +1,169 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with 
the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +import asyncio +from datetime import datetime, timedelta + +import pytest +from dapr.ext.workflow.aio import AsyncWorkflowContext, CoroutineOrchestratorRunner +from dapr.ext.workflow.aio.sandbox import sandbox_scope +from durabletask.aio.sandbox import SandboxMode + +""" +Tests for sandboxed asyncio.gather behavior in async orchestrators. +""" + + +class _FakeCtx: + def __init__(self): + self.current_utc_datetime = datetime(2024, 1, 1) + self.instance_id = 'test-instance' + + def create_timer(self, fire_at): + class _T: + def __init__(self): + self._parent = None + self.is_complete = False + + return _T() + + def wait_for_external_event(self, name: str): + class _T: + def __init__(self): + self._parent = None + self.is_complete = False + + return _T() + + +def drive(gen, results): + try: + gen.send(None) + i = 0 + while True: + gen.send(results[i]) + i += 1 + except StopIteration as stop: + return stop.value + + +async def _plain(value): + return value + + +async def awf_empty(ctx: AsyncWorkflowContext): + with sandbox_scope(ctx, SandboxMode.BEST_EFFORT): + out = await asyncio.gather() + return out + + +def test_sandbox_gather_empty_returns_list(): + fake = _FakeCtx() + runner = CoroutineOrchestratorRunner(awf_empty) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + out = drive(gen, results=[None]) + assert out == [] + + +async def awf_when_all(ctx: AsyncWorkflowContext): + a = ctx.create_timer(timedelta(seconds=0)) + b = ctx.wait_for_external_event('x') + with sandbox_scope(ctx, SandboxMode.BEST_EFFORT): + res = await asyncio.gather(a, b) + return res + + +def test_sandbox_gather_all_workflow_maps_to_when_all(): + fake = _FakeCtx() + runner = CoroutineOrchestratorRunner(awf_when_all) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + out = drive(gen, results=[[1, 2]]) + assert out == [1, 2] + + +async def awf_mixed(ctx: AsyncWorkflowContext): + a = ctx.create_timer(timedelta(seconds=0)) + with sandbox_scope(ctx, SandboxMode.BEST_EFFORT): + res = await asyncio.gather(a, _plain('ok')) + return res + + +def test_sandbox_gather_mixed_returns_sequential_results(): + fake = _FakeCtx() + runner = CoroutineOrchestratorRunner(awf_mixed) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + out = drive(gen, results=[123]) + assert out == [123, 'ok'] + + +async def awf_return_exceptions(ctx: AsyncWorkflowContext): + async def _boom(): + raise RuntimeError('x') + + a = ctx.create_timer(timedelta(seconds=0)) + with sandbox_scope(ctx, SandboxMode.BEST_EFFORT): + res = await asyncio.gather(a, _boom(), return_exceptions=True) + return res + + +def test_sandbox_gather_return_exceptions(): + fake = _FakeCtx() + runner = CoroutineOrchestratorRunner(awf_return_exceptions) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + out = drive(gen, results=[321]) + assert isinstance(out[1], RuntimeError) + + +async def awf_multi_await(ctx: AsyncWorkflowContext): + with sandbox_scope(ctx, SandboxMode.BEST_EFFORT): + g = asyncio.gather() + a = await g + b = await g + return (a, b) + + +def 
test_sandbox_gather_multi_await_safe(): + fake = _FakeCtx() + runner = CoroutineOrchestratorRunner(awf_multi_await) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + out = drive(gen, results=[None]) + assert out == ([], []) + + +def test_sandbox_gather_restored_outside(): + import asyncio as aio + + original = aio.gather + fake = _FakeCtx() + ctx = AsyncWorkflowContext(fake) + with sandbox_scope(ctx, SandboxMode.BEST_EFFORT): + pass + # After exit, gather should be restored + assert aio.gather is original + + +def test_strict_mode_blocks_create_task(): + import asyncio as aio + + fake = _FakeCtx() + ctx = AsyncWorkflowContext(fake) + with sandbox_scope(ctx, SandboxMode.STRICT): + if hasattr(aio, 'create_task'): + with pytest.raises(RuntimeError): + # Use a dummy coroutine to trigger the block + async def _c(): + return 1 + + aio.create_task(_c()) diff --git a/ext/dapr-ext-workflow/tests/test_trace_fields.py b/ext/dapr-ext-workflow/tests/test_trace_fields.py new file mode 100644 index 00000000..03d38e1e --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_trace_fields.py @@ -0,0 +1,60 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +from datetime import datetime, timezone + +from dapr.ext.workflow.execution_info import ActivityExecutionInfo, WorkflowExecutionInfo +from dapr.ext.workflow.workflow_activity_context import WorkflowActivityContext + + +class _FakeOrchCtx: + def __init__(self): + self.instance_id = 'wf-123' + self.current_utc_datetime = datetime(2025, 1, 1, tzinfo=timezone.utc) + self.is_replaying = False + self.workflow_name = 'wf_name' + self.parent_instance_id = 'parent-1' + self.history_event_sequence = 42 + self.trace_parent = '00-aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa-bbbbbbbbbbbbbbbb-01' + self.trace_state = 'vendor=state' + self.orchestration_span_id = 'bbbbbbbbbbbbbbbb' + + +class _FakeActivityCtx: + def __init__(self): + self.orchestration_id = 'wf-123' + self.task_id = 7 + self.trace_parent = '00-cccccccccccccccccccccccccccccccc-dddddddddddddddd-01' + self.trace_state = 'v=1' + + +def test_workflow_execution_info_minimal(): + ei = WorkflowExecutionInfo(inbound_metadata={'k': 'v'}) + assert ei.inbound_metadata == {'k': 'v'} + + +def test_activity_execution_info_minimal(): + aei = ActivityExecutionInfo(inbound_metadata={'m': 'v'}, activity_name='act_name') + assert aei.inbound_metadata == {'m': 'v'} + + +def test_workflow_activity_context_execution_info_trace_fields(): + base = _FakeActivityCtx() + actx = WorkflowActivityContext(base) + aei = ActivityExecutionInfo(inbound_metadata={}, activity_name='act_name') + actx._set_execution_info(aei) + got = actx.execution_info + assert got is not None + assert got.inbound_metadata == {} diff --git a/ext/dapr-ext-workflow/tests/test_tracing_interceptors.py b/ext/dapr-ext-workflow/tests/test_tracing_interceptors.py new file mode 100644 index 00000000..35f93361 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_tracing_interceptors.py @@ -0,0 +1,171 @@ +""" 
+Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +import uuid +from datetime import datetime +from typing import Any + +from dapr.ext.workflow import ( + ClientInterceptor, + DaprWorkflowClient, + RuntimeInterceptor, + WorkflowRuntime, +) + + +class _FakeRegistry: + def __init__(self): + self.orchestrators: dict[str, Any] = {} + self.activities: dict[str, Any] = {} + + def add_named_orchestrator(self, name, fn): + self.orchestrators[name] = fn + + def add_named_activity(self, name, fn): + self.activities[name] = fn + + +class _FakeWorker: + def __init__(self, *args, **kwargs): + self._registry = _FakeRegistry() + + def start(self): + pass + + def stop(self): + pass + + +class _FakeOrchestrationContext: + def __init__(self, *, is_replaying: bool = False): + self.instance_id = 'wf-1' + self.current_utc_datetime = datetime(2025, 1, 1) + self.is_replaying = is_replaying + self.workflow_name = 'wf' + self.parent_instance_id = None + self.history_event_sequence = 1 + self.trace_parent = None + self.trace_state = None + self.orchestration_span_id = None + + +def _drive_generator(gen, returned_value): + # Prime to first yield; then drive + next(gen) + while True: + try: + gen.send(returned_value) + except StopIteration as stop: + return stop.value + + +def test_client_injects_tracing_on_schedule(monkeypatch): + import durabletask.client as client_mod + + # monkeypatch TaskHubGrpcClient to capture inputs + scheduled: dict[str, Any] = {} + + class _FakeClient: + def __init__(self, *args, **kwargs): + pass + + def schedule_new_orchestration( + self, name, *, input=None, instance_id=None, start_at=None, reuse_id_policy=None + ): + scheduled['name'] = name + scheduled['input'] = input + scheduled['instance_id'] = instance_id + scheduled['start_at'] = start_at + scheduled['reuse_id_policy'] = reuse_id_policy + return 'id-1' + + monkeypatch.setattr(client_mod, 'TaskHubGrpcClient', _FakeClient) + + class _TracingClient(ClientInterceptor): + def schedule_new_workflow(self, request, next): # type: ignore[override] + tr = {'trace_id': uuid.uuid4().hex} + if isinstance(request.input, dict) and 'tracing' not in request.input: + request = type(request)( + workflow_name=request.workflow_name, + input={**request.input, 'tracing': tr}, + instance_id=request.instance_id, + start_at=request.start_at, + reuse_id_policy=request.reuse_id_policy, + ) + return next(request) + + client = DaprWorkflowClient(interceptors=[_TracingClient()]) + + # We only need a callable with a __name__ for scheduling + def wf(ctx): + yield 'noop' + + wf.__name__ = 'inject_test' + instance_id = client.schedule_new_workflow(wf, input={'a': 1}) + assert instance_id == 'id-1' + assert scheduled['name'] == 'inject_test' + assert isinstance(scheduled['input'], dict) + assert 'tracing' in scheduled['input'] + assert scheduled['input']['a'] == 1 + + +def test_runtime_restores_tracing_before_user_code(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 
'TaskHubGrpcWorker', _FakeWorker) + + seen: dict[str, Any] = {} + + class _TracingRuntime(RuntimeInterceptor): + def execute_workflow(self, request, next): # type: ignore[override] + # no-op; real restoration is app concern; test just ensures input contains tracing + return next(request) + + def execute_activity(self, request, next): # type: ignore[override] + return next(request) + + class _TracingClient2(ClientInterceptor): + def schedule_new_workflow(self, request, next): # type: ignore[override] + tr = {'trace_id': 't1'} + if isinstance(request.input, dict): + request = type(request)( + workflow_name=request.workflow_name, + input={**request.input, 'tracing': tr}, + instance_id=request.instance_id, + start_at=request.start_at, + reuse_id_policy=request.reuse_id_policy, + ) + return next(request) + + rt = WorkflowRuntime( + runtime_interceptors=[_TracingRuntime()], + ) + + @rt.workflow(name='w') + def w(ctx, x): + # The tracing should already be present in input + assert isinstance(x, dict) + assert 'tracing' in x + seen['trace'] = x['tracing'] + yield 'noop' + return 'ok' + + orch = rt._WorkflowRuntime__worker._registry.orchestrators['w'] + # Orchestrator input will have tracing injected via outbound when scheduled as a child or via client + # Here, we directly pass the input simulating schedule with tracing present + gen = orch(_FakeOrchestrationContext(), {'hello': 'world', 'tracing': {'trace_id': 't1'}}) + out = _drive_generator(gen, returned_value='noop') + assert out == 'ok' + assert seen['trace']['trace_id'] == 't1' diff --git a/ext/dapr-ext-workflow/tests/test_workflow_runtime.py b/ext/dapr-ext-workflow/tests/test_workflow_runtime.py index bf18cd68..f367a43f 100644 --- a/ext/dapr-ext-workflow/tests/test_workflow_runtime.py +++ b/ext/dapr-ext-workflow/tests/test_workflow_runtime.py @@ -37,7 +37,10 @@ class WorkflowRuntimeTest(unittest.TestCase): def setUp(self): listActivities.clear() listOrchestrators.clear() - mock.patch('durabletask.worker._Registry', return_value=FakeTaskHubGrpcWorker()).start() + self.patcher = mock.patch( + 'durabletask.worker._Registry', return_value=FakeTaskHubGrpcWorker() + ) + self.patcher.start() self.runtime_options = WorkflowRuntime() if hasattr(self.mock_client_wf, '_dapr_alternate_name'): del self.mock_client_wf.__dict__['_dapr_alternate_name'] @@ -48,6 +51,11 @@ def setUp(self): if hasattr(self.mock_client_activity, '_activity_registered'): del self.mock_client_activity.__dict__['_activity_registered'] + def tearDown(self): + """Stop the mock patch to prevent interference with other tests.""" + self.patcher.stop() + mock.patch.stopall() # Ensure all patches are stopped + def mock_client_wf(ctx: DaprWorkflowContext, input): print(f'{input}') diff --git a/ext/dapr-ext-workflow/tests/test_workflow_util.py b/ext/dapr-ext-workflow/tests/test_workflow_util.py index 28e92e6c..c1b980ed 100644 --- a/ext/dapr-ext-workflow/tests/test_workflow_util.py +++ b/ext/dapr-ext-workflow/tests/test_workflow_util.py @@ -1,3 +1,16 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +""" + import unittest from unittest.mock import patch @@ -7,6 +20,7 @@ class DaprWorkflowUtilTest(unittest.TestCase): + @patch.object(settings, 'DAPR_GRPC_ENDPOINT', '') def test_get_address_default(self): expected = f'{settings.DAPR_RUNTIME_HOST}:{settings.DAPR_GRPC_PORT}' self.assertEqual(expected, getAddress())
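
A short application-side sketch of the client-interceptor pattern these tests exercise, placed after the patch for reference. It is a minimal sketch assuming only the API surface shown in the tests above (ClientInterceptor, ScheduleWorkflowRequest, and DaprWorkflowClient(interceptors=...)); the AttachTenantMetadata class and the 'tenant.id' key are illustrative names, not part of the SDK.

from dapr.ext.workflow import (
    ClientInterceptor,
    DaprWorkflowClient,
    ScheduleWorkflowRequest,
)


class AttachTenantMetadata(ClientInterceptor):
    # Hypothetical interceptor: attach a tenant id to every scheduled workflow.
    def schedule_new_workflow(self, request: ScheduleWorkflowRequest, next):
        # Rebuild the request with metadata; per test_client_schedule_metadata_envelope,
        # the client then wraps the input into a
        # {'__dapr_meta__': ..., '__dapr_payload__': ...} envelope before scheduling.
        request = ScheduleWorkflowRequest(
            workflow_name=request.workflow_name,
            input=request.input,
            instance_id=request.instance_id,
            start_at=request.start_at,
            reuse_id_policy=request.reuse_id_policy,
            metadata={'tenant.id': 'acme'},  # illustrative key/value
        )
        return next(request)


client = DaprWorkflowClient(interceptors=[AttachTenantMetadata()])

On the worker side, metadata injected this way surfaces to a RuntimeInterceptor as ExecuteWorkflowRequest.metadata, which is what test_runtime_inbound_unwrap_and_metadata_visible verifies.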