diff --git a/Makefile b/Makefile
index 69daa40..3a387b0 100644
--- a/Makefile
+++ b/Makefile
@@ -7,6 +7,13 @@ test-unit:
 test-e2e:
	pytest -m e2e --verbose
 
+coverage-clean:
+	rm -f .coverage .coverage.* coverage.xml
+
+coverage-all: coverage-clean
+	pytest -m "not e2e" --durations=0 --cov=durabletask --cov-branch --cov-report=term-missing --cov-report=xml
+	pytest -m e2e --durations=0 --cov=durabletask --cov-branch --cov-report=term-missing --cov-report=xml --cov-append
+
 install:
	python3 -m pip install .
 
@@ -18,4 +25,4 @@ gen-proto:
	python3 -m grpc_tools.protoc --proto_path=. --python_out=. --pyi_out=. --grpc_python_out=. ./durabletask/internal/orchestrator_service.proto
	rm durabletask/internal/*.proto
 
-.PHONY: init test-unit test-e2e gen-proto install
+.PHONY: init test-unit test-e2e coverage-clean coverage-all gen-proto install
diff --git a/README.md b/README.md
index f6a0284..1cc387e 100644
--- a/README.md
+++ b/README.md
@@ -126,10 +126,62 @@ Orchestrations can be continued as new using the `continue_as_new` API. This API
 
 Orchestrations can be suspended using the `suspend_orchestration` client API and will remain suspended until resumed using the `resume_orchestration` client API. A suspended orchestration will stop processing new events, but will continue to buffer any that happen to arrive until resumed, ensuring that no data is lost. An orchestration can also be terminated using the `terminate_orchestration` client API. Terminated orchestrations will stop processing new events and will discard any buffered events.
 
-### Retry policies (TODO)
+### Retry policies
 
 Orchestrations can specify retry policies for activities and sub-orchestrations. These policies control how many times and how frequently an activity or sub-orchestration will be retried in the event of a transient error.
 
+#### Creating a retry policy
+
+```python
+from datetime import timedelta
+from durabletask import task
+
+retry_policy = task.RetryPolicy(
+    first_retry_interval=timedelta(seconds=1),  # Initial delay before first retry
+    max_number_of_attempts=5,  # Maximum total attempts (includes first attempt)
+    backoff_coefficient=2.0,  # Exponential backoff multiplier (must be >= 1)
+    max_retry_interval=timedelta(seconds=30),  # Cap on retry delay
+    retry_timeout=timedelta(minutes=5),  # Total time limit for all retries (optional)
+)
+```
+
+**Notes:**
+- `max_number_of_attempts` **includes the initial attempt**. For example, `max_number_of_attempts=5` means 1 initial attempt + up to 4 retries.
+- `retry_timeout` is optional. If omitted or set to `None`, retries continue until `max_number_of_attempts` is reached.
+- `backoff_coefficient` controls exponential backoff: delay = `first_retry_interval * (backoff_coefficient ^ retry_number)`, capped by `max_retry_interval` (see the worked example below).
+- `non_retryable_error_types` (optional) can specify additional exception types to treat as non-retryable (e.g., `[ValueError, TypeError]`). `NonRetryableError` is always non-retryable regardless of this setting.
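+
+To get a concrete feel for the schedule this formula produces, here is a small illustrative sketch (assuming `retry_number` starts at 0 for the first retry) that evaluates the delays implied by the example policy above:
+
+```python
+from datetime import timedelta
+
+first_retry_interval = timedelta(seconds=1)
+backoff_coefficient = 2.0
+max_retry_interval = timedelta(seconds=30)
+
+# max_number_of_attempts=5 leaves room for up to 4 retries after the initial attempt
+for retry_number in range(4):
+    delay = min(first_retry_interval * (backoff_coefficient**retry_number), max_retry_interval)
+    print(f"retry #{retry_number + 1} after {delay.total_seconds()}s")
+# retry #1 after 1.0s, #2 after 2.0s, #3 after 4.0s, #4 after 8.0s (the 30s cap is never reached here)
+```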
+
+#### Using retry policies
+
+Apply retry policies to activities or sub-orchestrations:
+
+```python
+def my_orchestrator(ctx: task.OrchestrationContext, data):
+    # Retry an activity
+    result = yield ctx.call_activity(my_activity, input=data, retry_policy=retry_policy)
+
+    # Retry a sub-orchestration
+    result = yield ctx.call_sub_orchestrator(child_orchestrator, input=data, retry_policy=retry_policy)
+```
+
+#### Non-retryable errors
+
+For errors that should not be retried (e.g., validation failures, permanent errors), raise a `NonRetryableError`:
+
+```python
+from durabletask.task import NonRetryableError
+
+def my_activity(ctx: task.ActivityContext, input):
+    if input is None:
+        # This error will bypass retry logic and fail immediately
+        raise NonRetryableError("Input cannot be None")
+
+    # Transient errors (network, timeouts, etc.) will be retried
+    return call_external_service(input)
+```
+
+Even when a retry policy is configured, raising a `NonRetryableError` causes the task to fail immediately, without any retry attempts.
+
 ## Getting Started
 
 ### Prerequisites
@@ -194,7 +246,7 @@ Certain aspects like multi-app activities require the full dapr runtime to be ru
 
 ```shell
 dapr init || true
-dapr run --app-id test-app --dapr-grpc-port 4001 --components-path ./examples/components/
+dapr run --app-id test-app --dapr-grpc-port 4001 --resources-path ./examples/components/
 ```
 
 To run the E2E tests on a specific python version (eg: 3.11), run the following command from the project root:
diff --git a/dev-requirements.txt b/dev-requirements.txt
index ba589ab..e69de29 100644
--- a/dev-requirements.txt
+++ b/dev-requirements.txt
@@ -1 +0,0 @@
-grpcio-tools==1.62.3 # 1.62.X is the latest version before protobuf 1.26.X is used which has breaking changes for Python # supports protobuf 6.x and aligns with generated code
\ No newline at end of file
diff --git a/durabletask/client.py b/durabletask/client.py
index 1e28f30..92466ec 100644
--- a/durabletask/client.py
+++ b/durabletask/client.py
@@ -127,9 +127,28 @@ def __init__(
             interceptors=interceptors,
             options=channel_options,
         )
+        self._channel = channel
         self._stub = stubs.TaskHubSidecarServiceStub(channel)
         self._logger = shared.get_logger("client", log_handler, log_formatter)
 
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc, tb):
+        try:
+            self.close()
+        finally:
+            return False
+
+    def close(self) -> None:
+        """Close the underlying gRPC channel."""
+        try:
+            # grpc.Channel.close() is idempotent
+            self._channel.close()
+        except Exception:
+            # Best-effort cleanup
+            pass
+
     def schedule_new_orchestration(
         self,
         orchestrator: Union[task.Orchestrator[TInput, TOutput], str],
@@ -188,10 +207,59 @@ def wait_for_orchestration_completion(
     ) -> Optional[OrchestrationState]:
         req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
         try:
-            grpc_timeout = None if timeout == 0 else timeout
-            self._logger.info(
-                f"Waiting {'indefinitely' if timeout == 0 else f'up to {timeout}s'} for instance '{instance_id}' to complete."
-            )
+            # gRPC timeout mapping (pytest unit tests may pass None explicitly)
+            grpc_timeout = None if (timeout is None or timeout == 0) else timeout
+
+            # If timeout is None or 0, skip pre-checks/polling and call server-side wait directly
+            if timeout is None or timeout == 0:
+                self._logger.info(
+                    f"Waiting {'indefinitely' if not timeout else f'up to {timeout}s'} for instance '{instance_id}' to complete."
+ ) + res: pb.GetInstanceResponse = self._stub.WaitForInstanceCompletion( + req, timeout=grpc_timeout + ) + state = new_orchestration_state(req.instanceId, res) + return state + + # For positive timeout, best-effort pre-check and short polling to avoid long server waits + try: + # First check if the orchestration is already completed + current_state = self.get_orchestration_state( + instance_id, fetch_payloads=fetch_payloads + ) + if current_state and current_state.runtime_status in [ + OrchestrationStatus.COMPLETED, + OrchestrationStatus.FAILED, + OrchestrationStatus.TERMINATED, + ]: + return current_state + + # Poll for completion with exponential backoff to handle eventual consistency + import time + + poll_timeout = min(timeout, 10) + poll_start = time.time() + poll_interval = 0.1 + + while time.time() - poll_start < poll_timeout: + current_state = self.get_orchestration_state( + instance_id, fetch_payloads=fetch_payloads + ) + + if current_state and current_state.runtime_status in [ + OrchestrationStatus.COMPLETED, + OrchestrationStatus.FAILED, + OrchestrationStatus.TERMINATED, + ]: + return current_state + + time.sleep(poll_interval) + poll_interval = min(poll_interval * 1.5, 1.0) # Exponential backoff, max 1s + except Exception: + # Ignore pre-check/poll issues (e.g., mocked stubs in unit tests) and fall back + pass + + self._logger.info(f"Waiting up to {timeout}s for instance '{instance_id}' to complete.") res: pb.GetInstanceResponse = self._stub.WaitForInstanceCompletion( req, timeout=grpc_timeout ) diff --git a/durabletask/deterministic.py b/durabletask/deterministic.py new file mode 100644 index 0000000..2943783 --- /dev/null +++ b/durabletask/deterministic.py @@ -0,0 +1,224 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +""" +Deterministic utilities for Durable Task workflows (async and generator). + +This module provides deterministic alternatives to non-deterministic Python +functions, ensuring workflow replay consistency across different executions. +It is shared by both the asyncio authoring model and the generator-based model. +""" + +import hashlib +import random +import string as _string +import uuid +from collections.abc import Sequence +from dataclasses import dataclass +from datetime import datetime, timedelta +from typing import Optional, TypeVar + + +@dataclass +class DeterminismSeed: + """Seed data for deterministic operations.""" + + instance_id: str + orchestration_unix_ts: int + + def to_int(self) -> int: + """Convert seed to integer for PRNG initialization.""" + combined = f"{self.instance_id}:{self.orchestration_unix_ts}" + hash_bytes = hashlib.sha256(combined.encode("utf-8")).digest() + return int.from_bytes(hash_bytes[:8], byteorder="big") + + +def derive_seed(instance_id: str, orchestration_time: datetime) -> int: + """ + Derive a deterministic seed from instance ID and orchestration time. 
+ """ + ts = int(orchestration_time.timestamp()) + return DeterminismSeed(instance_id=instance_id, orchestration_unix_ts=ts).to_int() + + +def deterministic_random(instance_id: str, orchestration_time: datetime) -> random.Random: + """ + Create a deterministic random number generator. + """ + seed = derive_seed(instance_id, orchestration_time) + return random.Random(seed) + + +def deterministic_uuid4(rnd: random.Random) -> uuid.UUID: + """ + Generate a deterministic UUID4 using the provided random generator. + + Note: This is deprecated in favor of deterministic_uuid_v5 which matches + the .NET implementation for cross-language compatibility. + """ + bytes_ = bytes(rnd.randrange(0, 256) for _ in range(16)) + bytes_list = list(bytes_) + bytes_list[6] = (bytes_list[6] & 0x0F) | 0x40 # Version 4 + bytes_list[8] = (bytes_list[8] & 0x3F) | 0x80 # Variant bits + return uuid.UUID(bytes=bytes(bytes_list)) + + +def deterministic_uuid_v5(instance_id: str, current_datetime: datetime, counter: int) -> uuid.UUID: + """ + Generate a deterministic UUID v5 matching the .NET implementation. + + This implementation matches the durabletask-dotnet NewGuid() method: + https://github.com/microsoft/durabletask-dotnet/blob/main/src/Worker/Core/Shims/TaskOrchestrationContextWrapper.cs + + Args: + instance_id: The orchestration instance ID. + current_datetime: The current orchestration datetime (frozen during replay). + counter: The per-call counter (starts at 0 on each replay). + + Returns: + A deterministic UUID v5 that will be the same across replays. + """ + # DNS namespace UUID - same as .NET DnsNamespaceValue + namespace = uuid.UUID("9e952958-5e33-4daf-827f-2fa12937b875") + + # Build name matching .NET format: instanceId_datetime_counter + # Using isoformat() which produces ISO 8601 format similar to .NET's ToString("o") + name = f"{instance_id}_{current_datetime.isoformat()}_{counter}" + + # Generate UUID v5 (SHA-1 based, matching .NET) + return uuid.uuid5(namespace, name) + + +class DeterministicContextMixin: + """ + Mixin providing deterministic helpers for workflow contexts. + + Assumes the inheriting class exposes `instance_id` and `current_utc_datetime` attributes. + + This implementation matches the .NET durabletask SDK approach with an explicit + counter for UUID generation that resets on each replay. + """ + + def __init__(self, *args, **kwargs): + """Initialize the mixin with UUID and timestamp counters.""" + super().__init__(*args, **kwargs) + # Counter for deterministic UUID generation (matches .NET newGuidCounter) + # This counter resets to 0 on each replay, ensuring determinism + self._uuid_counter: int = 0 + # Counter for deterministic timestamp sequencing (resets on replay) + self._timestamp_counter: int = 0 + + def now(self) -> datetime: + """Alias for deterministic current_utc_datetime.""" + return self.current_utc_datetime # type: ignore[attr-defined] + + def random(self) -> random.Random: + """Return a PRNG seeded deterministically from instance id and orchestration time.""" + rnd = deterministic_random( + self.instance_id, # type: ignore[attr-defined] + self.current_utc_datetime, # type: ignore[attr-defined] + ) + # Mark as deterministic for asyncio sandbox detector whitelisting of bound methods (randint, random) + try: + rnd._dt_deterministic = True + except Exception: + pass + return rnd + + def uuid4(self) -> uuid.UUID: + """ + Return a deterministically generated UUID v5 with explicit counter. 
+ https://www.sohamkamani.com/uuid-versions-explained/#v5-non-random-uuids + + This matches the .NET implementation's NewGuid() method which uses: + - Instance ID + - Current UTC datetime (frozen during replay) + - Per-call counter (resets to 0 on each replay) + + The counter ensures multiple calls produce different UUIDs while maintaining + determinism across replays. + """ + # Lazily initialize counter if not set by __init__ (for compatibility) + if not hasattr(self, "_uuid_counter"): + self._uuid_counter = 0 + + result = deterministic_uuid_v5( + self.instance_id, # type: ignore[attr-defined] + self.current_utc_datetime, # type: ignore[attr-defined] + self._uuid_counter, + ) + self._uuid_counter += 1 + return result + + def new_guid(self) -> uuid.UUID: + """Alias for uuid4 for API parity with other SDKs.""" + return self.uuid4() + + def random_string(self, length: int, *, alphabet: Optional[str] = None) -> str: + """Return a deterministically generated random string of the given length.""" + if length < 0: + raise ValueError("length must be non-negative") + chars = alphabet if alphabet is not None else (_string.ascii_letters + _string.digits) + if not chars: + raise ValueError("alphabet must not be empty") + rnd = self.random() + size = len(chars) + return "".join(chars[rnd.randrange(0, size)] for _ in range(length)) + + def random_int(self, min_value: int = 0, max_value: int = 2**31 - 1) -> int: + """Return a deterministic random integer in the specified range.""" + if min_value > max_value: + raise ValueError("min_value must be <= max_value") + rnd = self.random() + return rnd.randint(min_value, max_value) + + T = TypeVar("T") + + def random_choice(self, sequence: Sequence[T]) -> T: + """Return a deterministic random element from a non-empty sequence.""" + if not sequence: + raise IndexError("Cannot choose from empty sequence") + rnd = self.random() + return rnd.choice(sequence) + + def now_with_sequence(self) -> datetime: + """ + Return deterministic timestamp with microsecond increment per call. + + Each call returns: current_utc_datetime + (counter * 1 microsecond) + + This provides ordered, unique timestamps for tracing/telemetry while maintaining + determinism across replays. The counter resets to 0 on each replay (similar to + _uuid_counter pattern). + + Perfect for preserving event ordering within a workflow without requiring activities. 
+ + Returns: + datetime: Deterministic timestamp that increments on each call + + Example: + ```python + def workflow(ctx): + t1 = ctx.now_with_sequence() # 2024-01-01 12:00:00.000000 + result = yield ctx.call_activity(some_activity, input="data") + t2 = ctx.now_with_sequence() # 2024-01-01 12:00:00.000001 + # t1 < t2, preserving order for telemetry + ``` + """ + offset = timedelta(microseconds=self._timestamp_counter) + self._timestamp_counter += 1 + return self.current_utc_datetime + offset # type: ignore[attr-defined] + + def current_utc_datetime_with_sequence(self): + """Alias for now_with_sequence for API parity with other SDKs.""" + return self.now_with_sequence() diff --git a/durabletask/internal/shared.py b/durabletask/internal/shared.py index 3adb6b1..4fe6d73 100644 --- a/durabletask/internal/shared.py +++ b/durabletask/internal/shared.py @@ -102,7 +102,7 @@ def get_logger( # Add a default log handler if none is provided if log_handler is None: log_handler = logging.StreamHandler() - log_handler.setLevel(logging.INFO) + log_handler.setLevel(logging.DEBUG) logger.handlers.append(log_handler) # Set a default log formatter to our handler if none is provided diff --git a/durabletask/task.py b/durabletask/task.py index 66abc28..6626761 100644 --- a/durabletask/task.py +++ b/durabletask/task.py @@ -233,6 +233,16 @@ class OrchestrationStateError(Exception): pass +class NonRetryableError(Exception): + """Exception indicating the operation should not be retried. + + If an activity or sub-orchestration raises this exception, retry logic will be + bypassed and the failure will be returned immediately to the orchestrator. + """ + + pass + + class Task(ABC, Generic[T]): """Abstract base class for asynchronous tasks in a durable orchestration.""" @@ -397,7 +407,7 @@ def compute_next_delay(self) -> Optional[timedelta]: next_delay_f = min( next_delay_f, self._retry_policy.max_retry_interval.total_seconds() ) - return timedelta(seconds=next_delay_f) + return timedelta(seconds=next_delay_f) return None @@ -490,6 +500,7 @@ def __init__( backoff_coefficient: Optional[float] = 1.0, max_retry_interval: Optional[timedelta] = None, retry_timeout: Optional[timedelta] = None, + non_retryable_error_types: Optional[list[Union[str, type]]] = None, ): """Creates a new RetryPolicy instance. @@ -505,6 +516,11 @@ def __init__( The maximum retry interval to use for any retry attempt. retry_timeout : Optional[timedelta] The maximum amount of time to spend retrying the operation. + non_retryable_error_types : Optional[list[Union[str, type]]] + A list of exception type names or classes that should not be retried. + If a failure's error type matches any of these, the task fails immediately. + The built-in NonRetryableError is always treated as non-retryable regardless + of this setting. 
""" # validate inputs if first_retry_interval < timedelta(seconds=0): @@ -523,6 +539,17 @@ def __init__( self._backoff_coefficient = backoff_coefficient self._max_retry_interval = max_retry_interval self._retry_timeout = retry_timeout + # Normalize non-retryable error type names to a set of strings + names: Optional[set[str]] = None + if non_retryable_error_types: + names = set() + for t in non_retryable_error_types: + if isinstance(t, str): + if t: + names.add(t) + elif isinstance(t, type): + names.add(t.__name__) + self._non_retryable_error_types = names @property def first_retry_interval(self) -> timedelta: @@ -549,6 +576,15 @@ def retry_timeout(self) -> Optional[timedelta]: """The maximum amount of time to spend retrying the operation.""" return self._retry_timeout + @property + def non_retryable_error_types(self) -> Optional[set[str]]: + """Set of error type names that should not be retried. + + Comparison is performed against the errorType string provided by the + backend (typically the exception class name). + """ + return self._non_retryable_error_types + def get_name(fn: Callable) -> str: """Returns the name of the provided function""" diff --git a/durabletask/worker.py b/durabletask/worker.py index daa661b..08d1af9 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -19,7 +19,7 @@ import durabletask.internal.orchestrator_service_pb2 as pb import durabletask.internal.orchestrator_service_pb2_grpc as stubs import durabletask.internal.shared as shared -from durabletask import task +from durabletask import deterministic, task from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl TInput = TypeVar("TInput") @@ -159,6 +159,8 @@ class TaskHubGrpcWorker: interceptors to apply to the channel. Defaults to None. concurrency_options (Optional[ConcurrencyOptions], optional): Configuration for controlling worker concurrency limits. If None, default settings are used. + stop_timeout (float, optional): Maximum time in seconds to wait for the worker thread + to stop when calling stop(). Defaults to 30.0. Useful to set lower values in tests. Attributes: concurrency_options (ConcurrencyOptions): The current concurrency configuration. 
@@ -224,6 +226,7 @@ def __init__( interceptors: Optional[Sequence[shared.ClientInterceptor]] = None, concurrency_options: Optional[ConcurrencyOptions] = None, channel_options: Optional[Sequence[tuple[str, Any]]] = None, + stop_timeout: float = 30.0, ): self._registry = _Registry() self._host_address = host_address if host_address else shared.get_default_host_address() @@ -232,6 +235,12 @@ def __init__( self._is_running = False self._secure_channel = secure_channel self._channel_options = channel_options + self._stop_timeout = stop_timeout + # Track in-flight activity executions for graceful draining + import threading as _threading + + self._active_task_count = 0 + self._active_task_cv = _threading.Condition() # Use provided concurrency options or create default ones self._concurrency_options = ( @@ -249,6 +258,8 @@ def __init__( self._interceptors = None self._async_worker_manager = _AsyncWorkerManager(self._concurrency_options) + # Readiness flag set once the worker has an active stream to the sidecar + self._ready = Event() @property def concurrency_options(self) -> ConcurrencyOptions: @@ -351,6 +362,8 @@ def invalidate_connection(): pass current_channel = None current_stub = None + # No longer ready if connection is gone + self._ready.clear() def should_invalidate_connection(rpc_error): error_code = rpc_error.code() # type: ignore @@ -390,6 +403,8 @@ def should_invalidate_connection(rpc_error): self._logger.info( f"Successfully connected to {self._host_address}. Waiting for work items..." ) + # Signal readiness once stream is established + self._ready.set() # Use a thread to read from the blocking gRPC stream and forward to asyncio import queue @@ -398,7 +413,10 @@ def should_invalidate_connection(rpc_error): def stream_reader(): try: - for work_item in self._response_stream: + stream = self._response_stream + if stream is None: + return + for work_item in stream: # type: ignore work_item_queue.put(work_item) except Exception as e: work_item_queue.put(e) @@ -409,33 +427,42 @@ def stream_reader(): current_reader_thread.start() loop = asyncio.get_running_loop() while not self._shutdown.is_set(): - try: - work_item = await loop.run_in_executor(None, work_item_queue.get) - if isinstance(work_item, Exception): - raise work_item - request_type = work_item.WhichOneof("request") - self._logger.debug(f'Received "{request_type}" work item') - if work_item.HasField("orchestratorRequest"): - self._async_worker_manager.submit_orchestration( - self._execute_orchestrator, - work_item.orchestratorRequest, - stub, - work_item.completionToken, - ) - elif work_item.HasField("activityRequest"): - self._async_worker_manager.submit_activity( - self._execute_activity, - work_item.activityRequest, - stub, - work_item.completionToken, - ) - elif work_item.HasField("healthPing"): - pass - else: - self._logger.warning(f"Unexpected work item type: {request_type}") - except Exception as e: - self._logger.warning(f"Error in work item stream: {e}") - raise e + work_item = await loop.run_in_executor(None, work_item_queue.get) + if isinstance(work_item, Exception): + raise work_item + request_type = work_item.WhichOneof("request") + self._logger.debug(f'Received "{request_type}" work item') + if work_item.HasField("orchestratorRequest"): + self._async_worker_manager.submit_orchestration( + self._execute_orchestrator, + work_item.orchestratorRequest, + stub, + work_item.completionToken, + ) + elif work_item.HasField("activityRequest"): + # track active tasks for graceful shutdown + with self._active_task_cv: + 
self._active_task_count += 1
+
+                        def _tracked_execute_activity(req, stub_arg, token):
+                            try:
+                                return self._execute_activity(req, stub_arg, token)
+                            finally:
+                                # decrement active tasks
+                                with self._active_task_cv:
+                                    self._active_task_count -= 1
+                                    self._active_task_cv.notify_all()
+
+                        self._async_worker_manager.submit_activity(
+                            _tracked_execute_activity,
+                            work_item.activityRequest,
+                            stub,
+                            work_item.completionToken,
+                        )
+                    elif work_item.HasField("healthPing"):
+                        pass
+                    else:
+                        self._logger.warning(f"Unexpected work item type: {request_type}")
             current_reader_thread.join(timeout=1)
             self._logger.info("Work item stream ended normally")
         except grpc.RpcError as rpc_error:
@@ -489,10 +516,43 @@ def stop(self):
         if self._response_stream is not None:
             self._response_stream.cancel()
         if self._runLoop is not None:
-            self._runLoop.join(timeout=30)
+            self._runLoop.join(timeout=self._stop_timeout)
         self._async_worker_manager.shutdown()
         self._logger.info("Worker shutdown completed")
         self._is_running = False
+        self._ready.clear()
+
+    def wait_for_idle(self, timeout: Optional[float] = None) -> bool:
+        """Block until no in-flight activities are executing.
+        An in-flight activity is one that has been submitted to the worker but has not yet completed.
+        Note that an orchestration can already be in a terminal status while activities are still
+        in flight (for example, `when_any` completes without waiting for all of its tasks).
+
+        Returns True if idle within timeout; otherwise False.
+        """
+        end: Optional[float] = None
+        if timeout is not None:
+            import time as _t
+
+            end = _t.time() + timeout
+        with self._active_task_cv:
+            while self._active_task_count > 0:
+                remaining = None
+                if end is not None:
+                    import time as _t
+
+                    remaining = max(0.0, end - _t.time())
+                    if remaining == 0.0:
+                        return False
+                self._active_task_cv.wait(timeout=remaining)
+            return True
+
+    def wait_for_ready(self, timeout: Optional[float] = None) -> bool:
+        """Block until the worker has an active connection to the sidecar.
+
+        Returns True if the worker became ready within the timeout; otherwise False.
+        """
+        return self._ready.wait(timeout)
 
     def _execute_orchestrator(
         self,
@@ -527,6 +587,24 @@
 
         try:
             stub.CompleteOrchestratorTask(res)
+        except grpc.RpcError as rpc_error:  # type: ignore
+            # During shutdown or if the instance was terminated, the channel may be closed
+            # or the instance may no longer be recognized by the sidecar. Treat these as benign.
+ code = rpc_error.code() # type: ignore + details = str(rpc_error) + benign = ( + code in {grpc.StatusCode.CANCELLED, grpc.StatusCode.UNAVAILABLE} + or "unknown instance ID/task ID combo" in details + or "Channel closed" in details + ) + if self._shutdown.is_set() or benign: + self._logger.debug( + f"Ignoring orchestrator completion delivery error during shutdown/benign condition: {rpc_error}" + ) + else: + self._logger.exception( + f"Failed to deliver orchestrator response for '{req.instanceId}' to sidecar: {rpc_error}" + ) except Exception as ex: self._logger.exception( f"Failed to deliver orchestrator response for '{req.instanceId}' to sidecar: {ex}" @@ -558,19 +636,42 @@ def _execute_activity( try: stub.CompleteActivityTask(res) + except grpc.RpcError as rpc_error: # type: ignore + # Treat common shutdown/termination races as benign to avoid noisy logs + code = rpc_error.code() # type: ignore + details = str(rpc_error) + benign = code in { + grpc.StatusCode.CANCELLED, + grpc.StatusCode.UNAVAILABLE, + grpc.StatusCode.UNKNOWN, + } and ( + "unknown instance ID/task ID combo" in details + or "Channel closed" in details + or "Locally cancelled by application" in details + ) + if self._shutdown.is_set() or benign: + self._logger.debug( + f"Ignoring activity completion delivery error during shutdown/benign condition: {rpc_error}" + ) + else: + self._logger.exception( + f"Failed to deliver activity response for '{req.name}#{req.taskId}' of orchestration ID '{instance_id}' to sidecar: {rpc_error}" + ) except Exception as ex: self._logger.exception( f"Failed to deliver activity response for '{req.name}#{req.taskId}' of orchestration ID '{instance_id}' to sidecar: {ex}" ) -class _RuntimeOrchestrationContext(task.OrchestrationContext): +class _RuntimeOrchestrationContext(task.OrchestrationContext, deterministic.DeterministicContextMixin): _generator: Optional[Generator[task.Task, Any, Any]] _previous_task: Optional[task.Task] def __init__(self, instance_id: str): + super().__init__() self._generator = None self._is_replaying = True + self._is_suspended = False self._is_complete = False self._result = None self._pending_actions: dict[int, pb.OrchestratorAction] = {} @@ -721,6 +822,10 @@ def current_utc_datetime(self, value: datetime): def is_replaying(self) -> bool: return self._is_replaying + @property + def is_suspended(self) -> bool: + return self._is_suspended + def set_custom_status(self, custom_status: Any) -> None: self._encoded_custom_status = ( shared.to_json(custom_status) if custom_status is not None else None @@ -802,7 +907,8 @@ def call_activity_function_helper( id = self.next_sequence_number() router = pb.TaskRouter() - router.sourceAppID = self._app_id + if self._app_id is not None: + router.sourceAppID = self._app_id if app_id is not None: router.targetAppID = app_id @@ -947,7 +1053,7 @@ def execute( return ExecutionResults(actions=actions, encoded_custom_status=ctx._encoded_custom_status) def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEvent) -> None: - if self._is_suspended and _is_suspendable(event): + if self._is_suspended and _is_suspendable(event) and not ctx.is_replaying: # We are suspended, so we need to buffer this event until we are resumed self._suspended_events.append(event) return @@ -1078,16 +1184,37 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven if isinstance(activity_task, task.RetryableTask): if activity_task._retry_policy is not None: - next_delay = activity_task.compute_next_delay() - if next_delay 
is None: + # Check for non-retryable errors by type name + error_type = event.taskFailed.failureDetails.errorType + policy = activity_task._retry_policy + is_non_retryable = False + if error_type == getattr( + task.NonRetryableError, "__name__", "NonRetryableError" + ): + is_non_retryable = True + elif ( + policy.non_retryable_error_types is not None + and error_type in policy.non_retryable_error_types + ): + is_non_retryable = True + + if is_non_retryable: activity_task.fail( f"{ctx.instance_id}: Activity task #{task_id} failed: {event.taskFailed.failureDetails.errorMessage}", event.taskFailed.failureDetails, ) ctx.resume() else: - activity_task.increment_attempt_count() - ctx.create_timer_internal(next_delay, activity_task) + next_delay = activity_task.compute_next_delay() + if next_delay is None: + activity_task.fail( + f"{ctx.instance_id}: Activity task #{task_id} failed: {event.taskFailed.failureDetails.errorMessage}", + event.taskFailed.failureDetails, + ) + ctx.resume() + else: + activity_task.increment_attempt_count() + ctx.create_timer_internal(next_delay, activity_task) elif isinstance(activity_task, task.CompletableTask): activity_task.fail( f"{ctx.instance_id}: Activity task #{task_id} failed: {event.taskFailed.failureDetails.errorMessage}", @@ -1145,16 +1272,37 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven return if isinstance(sub_orch_task, task.RetryableTask): if sub_orch_task._retry_policy is not None: - next_delay = sub_orch_task.compute_next_delay() - if next_delay is None: + # Check for non-retryable errors by type name + error_type = failedEvent.failureDetails.errorType + policy = sub_orch_task._retry_policy + is_non_retryable = False + if error_type == getattr( + task.NonRetryableError, "__name__", "NonRetryableError" + ): + is_non_retryable = True + elif ( + policy.non_retryable_error_types is not None + and error_type in policy.non_retryable_error_types + ): + is_non_retryable = True + + if is_non_retryable: sub_orch_task.fail( f"Sub-orchestration task #{task_id} failed: {failedEvent.failureDetails.errorMessage}", failedEvent.failureDetails, ) ctx.resume() else: - sub_orch_task.increment_attempt_count() - ctx.create_timer_internal(next_delay, sub_orch_task) + next_delay = sub_orch_task.compute_next_delay() + if next_delay is None: + sub_orch_task.fail( + f"Sub-orchestration task #{task_id} failed: {failedEvent.failureDetails.errorMessage}", + failedEvent.failureDetails, + ) + ctx.resume() + else: + sub_orch_task.increment_attempt_count() + ctx.create_timer_internal(next_delay, sub_orch_task) elif isinstance(sub_orch_task, task.CompletableTask): sub_orch_task.fail( f"Sub-orchestration task #{task_id} failed: {failedEvent.failureDetails.errorMessage}", @@ -1195,10 +1343,12 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven if not self._is_suspended and not ctx.is_replaying: self._logger.info(f"{ctx.instance_id}: Execution suspended.") self._is_suspended = True + ctx._is_suspended = True elif event.HasField("executionResumed") and self._is_suspended: if not ctx.is_replaying: self._logger.info(f"{ctx.instance_id}: Resuming execution.") self._is_suspended = False + ctx._is_suspended = False for e in self._suspended_events: self.process_event(ctx, e) self._suspended_events = [] diff --git a/examples/components/statestore.yaml b/examples/components/statestore.yaml new file mode 100644 index 0000000..a2b567a --- /dev/null +++ b/examples/components/statestore.yaml @@ -0,0 +1,16 @@ +apiVersion: 
dapr.io/v1alpha1 +kind: Component +metadata: + name: statestore +spec: + type: state.redis + version: v1 + metadata: + - name: redisHost + value: localhost:6379 + - name: redisPassword + value: "" + - name: actorStateStore + value: "true" + - name: keyPrefix + value: "workflow" \ No newline at end of file diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..3480ecb --- /dev/null +++ b/mypy.ini @@ -0,0 +1,48 @@ +[mypy] +# Global mypy configuration for durabletask-python + +# Target Python version +python_version = 3.9 + +# Directories to check +files = durabletask/ + +# Strict mode settings +strict = True +warn_return_any = True +warn_unused_configs = True +disallow_any_generics = True +disallow_subclassing_any = True +disallow_untyped_calls = True +disallow_untyped_defs = True +disallow_incomplete_defs = True +check_untyped_defs = True +disallow_untyped_decorators = True +no_implicit_optional = True +warn_redundant_casts = True +warn_unused_ignores = True +warn_no_return = True +warn_unreachable = True + +# Error output +show_error_codes = True +show_column_numbers = True +pretty = True + +# Third-party library stubs +ignore_missing_imports = True + +# Specific module configurations +[mypy-durabletask.aio.*] +# Extra strict for the new asyncio module +strict = True +warn_return_any = True + +[mypy-durabletask.internal.*] +# Generated protobuf code - less strict +ignore_errors = True + +[mypy-tests.*] +# Test files - slightly more lenient +disallow_untyped_defs = False +disallow_incomplete_defs = False diff --git a/requirements.txt b/requirements.txt index 7b288f0..b6902e9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1 @@ -# requirements in pyproject.toml +# pyproject.toml has the dependencies for this project \ No newline at end of file diff --git a/tests/durabletask/test_client.py b/tests/durabletask/test_client.py index b671cf8..1cc97e4 100644 --- a/tests/durabletask/test_client.py +++ b/tests/durabletask/test_client.py @@ -1,4 +1,4 @@ -from unittest.mock import patch +from unittest.mock import MagicMock, patch from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl from durabletask.internal.shared import get_default_host_address, get_grpc_channel @@ -140,3 +140,17 @@ def test_sync_channel_passes_base_options_and_max_lengths(): assert ("grpc.max_send_message_length", 1234) in opts assert ("grpc.max_receive_message_length", 5678) in opts assert ("grpc.primary_user_agent", "durabletask-tests") in opts + + +def test_taskhub_client_close_handles_exceptions(): + """Test that close() handles exceptions gracefully (edge case not easily testable in E2E).""" + with patch("durabletask.internal.shared.get_grpc_channel") as mock_get_channel: + mock_channel = MagicMock() + mock_channel.close.side_effect = Exception("close failed") + mock_get_channel.return_value = mock_channel + + from durabletask import client + + task_hub_client = client.TaskHubGrpcClient() + # Should not raise exception + task_hub_client.close() diff --git a/tests/durabletask/test_deterministic.py b/tests/durabletask/test_deterministic.py new file mode 100644 index 0000000..baf400d --- /dev/null +++ b/tests/durabletask/test_deterministic.py @@ -0,0 +1,453 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import random +import uuid +from datetime import datetime, timezone + +import pytest + +from durabletask.deterministic import ( + DeterminismSeed, + derive_seed, + deterministic_random, + deterministic_uuid4, + deterministic_uuid_v5, +) +from durabletask.worker import _RuntimeOrchestrationContext + + +class TestDeterminismSeed: + """Test DeterminismSeed dataclass and its methods.""" + + def test_to_int_produces_consistent_result(self): + """Test that to_int produces the same result for same inputs.""" + seed1 = DeterminismSeed(instance_id="test-123", orchestration_unix_ts=1234567890) + seed2 = DeterminismSeed(instance_id="test-123", orchestration_unix_ts=1234567890) + assert seed1.to_int() == seed2.to_int() + + def test_to_int_produces_different_results_for_different_instance_ids(self): + """Test that different instance IDs produce different seeds.""" + seed1 = DeterminismSeed(instance_id="test-123", orchestration_unix_ts=1234567890) + seed2 = DeterminismSeed(instance_id="test-456", orchestration_unix_ts=1234567890) + assert seed1.to_int() != seed2.to_int() + + def test_to_int_produces_different_results_for_different_timestamps(self): + """Test that different timestamps produce different seeds.""" + seed1 = DeterminismSeed(instance_id="test-123", orchestration_unix_ts=1234567890) + seed2 = DeterminismSeed(instance_id="test-123", orchestration_unix_ts=1234567891) + assert seed1.to_int() != seed2.to_int() + + def test_to_int_returns_positive_integer(self): + """Test that to_int returns a positive integer.""" + seed = DeterminismSeed(instance_id="test-123", orchestration_unix_ts=1234567890) + result = seed.to_int() + assert isinstance(result, int) + assert result >= 0 + + +class TestDeriveSeed: + """Test derive_seed function.""" + + def test_derive_seed_is_deterministic(self): + """Test that derive_seed produces consistent results.""" + instance_id = "test-instance" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + seed1 = derive_seed(instance_id, dt) + seed2 = derive_seed(instance_id, dt) + assert seed1 == seed2 + + def test_derive_seed_different_for_different_times(self): + """Test that different times produce different seeds.""" + instance_id = "test-instance" + dt1 = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + dt2 = datetime(2025, 1, 1, 12, 0, 1, tzinfo=timezone.utc) + seed1 = derive_seed(instance_id, dt1) + seed2 = derive_seed(instance_id, dt2) + assert seed1 != seed2 + + def test_derive_seed_handles_timezone_aware_datetime(self): + """Test that derive_seed works with timezone-aware datetimes.""" + instance_id = "test-instance" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + seed = derive_seed(instance_id, dt) + assert isinstance(seed, int) + + +class TestDeterministicRandom: + """Test deterministic_random function.""" + + def test_deterministic_random_returns_random_object(self): + """Test that deterministic_random returns a Random instance.""" + instance_id = "test-instance" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + rnd = deterministic_random(instance_id, dt) + assert isinstance(rnd, random.Random) + + def 
test_deterministic_random_produces_same_sequence(self): + """Test that same inputs produce same random sequence.""" + instance_id = "test-instance" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + rnd1 = deterministic_random(instance_id, dt) + rnd2 = deterministic_random(instance_id, dt) + + sequence1 = [rnd1.random() for _ in range(10)] + sequence2 = [rnd2.random() for _ in range(10)] + assert sequence1 == sequence2 + + def test_deterministic_random_different_for_different_inputs(self): + """Test that different inputs produce different sequences.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + rnd1 = deterministic_random("instance-1", dt) + rnd2 = deterministic_random("instance-2", dt) + + val1 = rnd1.random() + val2 = rnd2.random() + assert val1 != val2 + + +class TestDeterministicUuid4: + """Test deterministic_uuid4 function.""" + + def test_deterministic_uuid4_returns_valid_uuid(self): + """Test that deterministic_uuid4 returns a valid UUID4.""" + rnd = random.Random(42) + result = deterministic_uuid4(rnd) + assert isinstance(result, uuid.UUID) + assert result.version == 4 + + def test_deterministic_uuid4_is_deterministic(self): + """Test that same random state produces same UUID.""" + rnd1 = random.Random(42) + rnd2 = random.Random(42) + uuid1 = deterministic_uuid4(rnd1) + uuid2 = deterministic_uuid4(rnd2) + assert uuid1 == uuid2 + + def test_deterministic_uuid4_different_for_different_seeds(self): + """Test that different seeds produce different UUIDs.""" + rnd1 = random.Random(42) + rnd2 = random.Random(43) + uuid1 = deterministic_uuid4(rnd1) + uuid2 = deterministic_uuid4(rnd2) + assert uuid1 != uuid2 + + +class TestDeterministicUuidV5: + """Test deterministic_uuid_v5 function (matching .NET implementation).""" + + def test_deterministic_uuid_v5_returns_valid_uuid(self): + """Test that deterministic_uuid_v5 returns a valid UUID v5.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + result = deterministic_uuid_v5("test-instance", dt, 0) + assert isinstance(result, uuid.UUID) + assert result.version == 5 + + def test_deterministic_uuid_v5_is_deterministic(self): + """Test that same inputs produce same UUID.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + uuid1 = deterministic_uuid_v5("test-instance", dt, 0) + uuid2 = deterministic_uuid_v5("test-instance", dt, 0) + assert uuid1 == uuid2 + + def test_deterministic_uuid_v5_different_for_different_counters(self): + """Test that different counters produce different UUIDs.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + uuid1 = deterministic_uuid_v5("test-instance", dt, 0) + uuid2 = deterministic_uuid_v5("test-instance", dt, 1) + assert uuid1 != uuid2 + + def test_deterministic_uuid_v5_different_for_different_instance_ids(self): + """Test that different instance IDs produce different UUIDs.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + uuid1 = deterministic_uuid_v5("instance-1", dt, 0) + uuid2 = deterministic_uuid_v5("instance-2", dt, 0) + assert uuid1 != uuid2 + + def test_deterministic_uuid_v5_different_for_different_datetimes(self): + """Test that different datetimes produce different UUIDs.""" + dt1 = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + dt2 = datetime(2025, 1, 1, 12, 0, 1, tzinfo=timezone.utc) + uuid1 = deterministic_uuid_v5("test-instance", dt1, 0) + uuid2 = deterministic_uuid_v5("test-instance", dt2, 0) + assert uuid1 != uuid2 + + def test_deterministic_uuid_v5_matches_expected_format(self): + """Test that UUID v5 uses the 
correct namespace.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + result = deterministic_uuid_v5("test-instance", dt, 0) + # Should be deterministic - same inputs always produce same output + expected = deterministic_uuid_v5("test-instance", dt, 0) + assert result == expected + + def test_deterministic_uuid_v5_counter_sequence(self): + """Test that incrementing counter produces different UUIDs in sequence.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + uuids = [deterministic_uuid_v5("test-instance", dt, i) for i in range(5)] + # All should be different + assert len(set(uuids)) == 5 + # But calling with same counter should produce same UUID + assert uuids[0] == deterministic_uuid_v5("test-instance", dt, 0) + assert uuids[4] == deterministic_uuid_v5("test-instance", dt, 4) + + +def mock_deterministic_context(instance_id: str, current_utc_datetime: datetime) -> _RuntimeOrchestrationContext: + """Mock context for testing DeterministicContextMixin.""" + ctx = _RuntimeOrchestrationContext(instance_id) + ctx.current_utc_datetime = current_utc_datetime + return ctx + + +class TestDeterministicContextMixin: + """Test DeterministicContextMixin methods.""" + + def test_now_returns_current_utc_datetime(self): + """Test that now() returns the orchestration time.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + ctx = mock_deterministic_context("test-instance", dt) + assert ctx.now() == dt + + def test_random_returns_deterministic_prng(self): + """Test that random() returns a deterministic PRNG.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + ctx = mock_deterministic_context("test-instance", dt) + rnd1 = ctx.random() + rnd2 = ctx.random() + + # Both should produce same sequence + assert isinstance(rnd1, random.Random) + assert isinstance(rnd2, random.Random) + assert rnd1.random() == rnd2.random() + + def test_random_has_deterministic_marker(self): + """Test that random() sets _dt_deterministic marker.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + ctx = mock_deterministic_context("test-instance", dt) + rnd = ctx.random() + assert hasattr(rnd, "_dt_deterministic") + assert rnd._dt_deterministic is True + + def test_uuid4_generates_deterministic_uuid(self): + """Test that uuid4() generates deterministic UUIDs v5 with counter.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + ctx1 = mock_deterministic_context("test-instance", dt) + ctx2 = mock_deterministic_context("test-instance", dt) + + uuid1 = ctx1.uuid4() + uuid2 = ctx2.uuid4() + + assert isinstance(uuid1, uuid.UUID) + assert uuid1.version == 5 # Now using UUID v5 like .NET + assert uuid1 == uuid2 # Same counter (0) produces same UUID + + def test_uuid4_increments_counter(self): + """Test that uuid4() increments counter producing different UUIDs.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + ctx = mock_deterministic_context("test-instance", dt) + + uuid1 = ctx.uuid4() # counter=0 + uuid2 = ctx.uuid4() # counter=1 + uuid3 = ctx.uuid4() # counter=2 + + # All should be different due to counter + assert uuid1 != uuid2 + assert uuid2 != uuid3 + assert uuid1 != uuid3 + + def test_uuid4_counter_resets_on_replay(self): + """Test that counter resets on new context (simulating replay).""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + + # First execution + ctx1 = mock_deterministic_context("test-instance", dt) + uuid1_first = ctx1.uuid4() # counter=0 + uuid1_second = ctx1.uuid4() # counter=1 + + # Replay - new context, counter resets + 
ctx2 = mock_deterministic_context("test-instance", dt) + uuid2_first = ctx2.uuid4() # counter=0 + uuid2_second = ctx2.uuid4() # counter=1 + + # Same counter values produce same UUIDs (determinism!) + assert uuid1_first == uuid2_first + assert uuid1_second == uuid2_second + + def test_new_guid_is_alias_for_uuid4(self): + """Test that new_guid() is an alias for uuid4().""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + ctx = mock_deterministic_context("test-instance", dt) + + guid1 = ctx.new_guid() # counter=0 + guid2 = ctx.uuid4() # counter=1 + + # Both should be v5 UUIDs, but different due to counter increment + assert isinstance(guid1, uuid.UUID) + assert isinstance(guid2, uuid.UUID) + assert guid1.version == 5 + assert guid2.version == 5 + assert guid1 != guid2 # Different due to counter + + # Verify determinism - same counter produces same UUID + ctx2 = mock_deterministic_context("test-instance", dt) + guid3 = ctx2.new_guid() # counter=0 + assert guid3 == guid1 # Same as first call + + def test_random_string_generates_string_of_correct_length(self): + """Test that random_string() generates string of specified length.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + ctx = mock_deterministic_context("test-instance", dt) + + s = ctx.random_string(10) + assert len(s) == 10 + + def test_random_string_is_deterministic(self): + """Test that random_string() produces consistent results.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + ctx1 = mock_deterministic_context("test-instance", dt) + ctx2 = mock_deterministic_context("test-instance", dt) + + s1 = ctx1.random_string(20) + s2 = ctx2.random_string(20) + assert s1 == s2 + + def test_random_string_uses_default_alphabet(self): + """Test that random_string() uses alphanumeric characters by default.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + ctx = mock_deterministic_context("test-instance", dt) + + s = ctx.random_string(100) + assert all(c.isalnum() for c in s) + + def test_random_string_uses_custom_alphabet(self): + """Test that random_string() respects custom alphabet.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + ctx = mock_deterministic_context("test-instance", dt) + + s = ctx.random_string(50, alphabet="ABC") + assert all(c in "ABC" for c in s) + + def test_random_string_raises_on_negative_length(self): + """Test that random_string() raises ValueError for negative length.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + ctx = mock_deterministic_context("test-instance", dt) + + with pytest.raises(ValueError, match="length must be non-negative"): + ctx.random_string(-1) + + def test_random_string_raises_on_empty_alphabet(self): + """Test that random_string() raises ValueError for empty alphabet.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + ctx = mock_deterministic_context("test-instance", dt) + + with pytest.raises(ValueError, match="alphabet must not be empty"): + ctx.random_string(10, alphabet="") + + def test_random_string_handles_zero_length(self): + """Test that random_string() handles zero length correctly.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + ctx = mock_deterministic_context("test-instance", dt) + + s = ctx.random_string(0) + assert s == "" + + def test_random_int_generates_int_in_range(self): + """Test that random_int() generates integer in specified range.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + ctx = mock_deterministic_context("test-instance", dt) + + for _ in 
range(10): + val = ctx.random_int(10, 20) + assert 10 <= val <= 20 + + def test_random_int_is_deterministic(self): + """Test that random_int() produces consistent results.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + ctx1 = mock_deterministic_context("test-instance", dt) + ctx2 = mock_deterministic_context("test-instance", dt) + + val1 = ctx1.random_int(0, 1000) + val2 = ctx2.random_int(0, 1000) + assert val1 == val2 + + def test_random_int_uses_default_range(self): + """Test that random_int() uses default range when not specified.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + ctx = mock_deterministic_context("test-instance", dt) + + val = ctx.random_int() + assert 0 <= val <= 2**31 - 1 + + def test_random_int_raises_on_invalid_range(self): + """Test that random_int() raises ValueError when min > max.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + ctx = mock_deterministic_context("test-instance", dt) + + with pytest.raises(ValueError, match="min_value must be <= max_value"): + ctx.random_int(20, 10) + + def test_random_int_handles_same_min_and_max(self): + """Test that random_int() handles case where min equals max.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + ctx = mock_deterministic_context("test-instance", dt) + + val = ctx.random_int(42, 42) + assert val == 42 + + def test_random_choice_picks_from_sequence(self): + """Test that random_choice() picks element from sequence.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + ctx = mock_deterministic_context("test-instance", dt) + + choices = ["a", "b", "c", "d", "e"] + result = ctx.random_choice(choices) + assert result in choices + + def test_random_choice_is_deterministic(self): + """Test that random_choice() produces consistent results.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + ctx1 = mock_deterministic_context("test-instance", dt) + ctx2 = mock_deterministic_context("test-instance", dt) + + choices = list(range(100)) + result1 = ctx1.random_choice(choices) + result2 = ctx2.random_choice(choices) + assert result1 == result2 + + def test_random_choice_raises_on_empty_sequence(self): + """Test that random_choice() raises IndexError for empty sequence.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + ctx = mock_deterministic_context("test-instance", dt) + + with pytest.raises(IndexError, match="Cannot choose from empty sequence"): + ctx.random_choice([]) + + def test_random_choice_works_with_different_sequence_types(self): + """Test that random_choice() works with various sequence types.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + ctx = mock_deterministic_context("test-instance", dt) + + # List + result = ctx.random_choice([1, 2, 3]) + assert result in [1, 2, 3] + + # Reset context for deterministic behavior + ctx = mock_deterministic_context("test-instance", dt) + # Tuple + result = ctx.random_choice((1, 2, 3)) + assert result in (1, 2, 3) + + # Reset context for deterministic behavior + ctx = mock_deterministic_context("test-instance", dt) + # String + result = ctx.random_choice("abc") + assert result in "abc" diff --git a/tests/durabletask/test_orchestration_e2e.py b/tests/durabletask/test_orchestration_e2e.py index 225456d..5b5b85b 100644 --- a/tests/durabletask/test_orchestration_e2e.py +++ b/tests/durabletask/test_orchestration_e2e.py @@ -5,6 +5,7 @@ import threading import time from datetime import timedelta +from typing import Optional import pytest @@ -16,6 +17,32 @@ pytestmark = 
pytest.mark.e2e +def _wait_until_terminal( + hub_client: client.TaskHubGrpcClient, + instance_id: str, + *, + timeout_s: int = 30, + fetch_payloads: bool = True, +) -> Optional[client.OrchestrationState]: + """Polling-based completion wait that does not rely on the completion stream. + + Returns the terminal state or None if timeout. + """ + deadline = time.time() + timeout_s + delay = 0.1 + while time.time() < deadline: + st = hub_client.get_orchestration_state(instance_id, fetch_payloads=fetch_payloads) + if st and st.runtime_status in ( + client.OrchestrationStatus.COMPLETED, + client.OrchestrationStatus.FAILED, + client.OrchestrationStatus.TERMINATED, + ): + return st + time.sleep(delay) + delay = min(delay * 1.5, 1.0) + return None + + def test_empty_orchestration(): invoked = False @@ -31,12 +58,18 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): with worker.TaskHubGrpcWorker(channel_options=channel_options) as w: w.add_orchestrator(empty_orchestrator) w.start() + w.wait_for_ready(timeout=10) # set a custom max send length option c = client.TaskHubGrpcClient(channel_options=channel_options) id = c.schedule_new_orchestration(empty_orchestrator) state = c.wait_for_orchestration_completion(id, timeout=30) + # Test calling wait again on already-completed orchestration (should return immediately) + state2 = c.wait_for_orchestration_completion(id, timeout=30) + assert state2 is not None + assert state2.runtime_status == client.OrchestrationStatus.COMPLETED + assert invoked assert state is not None assert state.name == task.get_name(empty_orchestrator) @@ -48,6 +81,41 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): assert state.serialized_custom_status is None +def test_wait_for_idle(): + """Test that wait_for_idle properly waits for in-flight activities to complete.""" + import time + + def slow_activity(ctx: task.ActivityContext, input: int): + # Simulate slow activity + time.sleep(0.2) + return input + 1 + + def orchestrator(ctx: task.OrchestrationContext, input: int): + # Schedule multiple activities without waiting + tasks = [ctx.call_activity(slow_activity, input=i) for i in range(3)] + # Wait for all to complete + results = yield task.when_all(tasks) + return sum(results) + + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: + w.add_orchestrator(orchestrator) + w.add_activity(slow_activity) + w.start() + w.wait_for_ready(timeout=10) + + with client.TaskHubGrpcClient() as c: + id = c.schedule_new_orchestration(orchestrator, input=1) + + # Wait for orchestration to complete + state = c.wait_for_orchestration_completion(id, timeout=30) + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + + # Wait for any lingering activities to finish + idle = w.wait_for_idle(timeout=5.0) + assert idle is True + + def test_activity_sequence(): def plus_one(_: task.ActivityContext, input: int) -> int: return input + 1 @@ -61,14 +129,15 @@ def sequence(ctx: task.OrchestrationContext, start_val: int): return numbers # Start a worker, which will connect to the sidecar in a background thread - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: w.add_orchestrator(sequence) w.add_activity(plus_one) w.start() + w.wait_for_ready(timeout=10) - task_hub_client = client.TaskHubGrpcClient() - id = task_hub_client.schedule_new_orchestration(sequence, input=1) - state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) + with client.TaskHubGrpcClient() as task_hub_client: + id = 
task_hub_client.schedule_new_orchestration(sequence, input=1) + state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) assert state is not None assert state.name == task.get_name(sequence) @@ -104,15 +173,16 @@ def orchestrator(ctx: task.OrchestrationContext, input: int): return error_msg # Start a worker, which will connect to the sidecar in a background thread - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: w.add_orchestrator(orchestrator) w.add_activity(throw) w.add_activity(increment_counter) w.start() + w.wait_for_ready(timeout=10) - task_hub_client = client.TaskHubGrpcClient() - id = task_hub_client.schedule_new_orchestration(orchestrator, input=1) - state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) + with client.TaskHubGrpcClient() as task_hub_client: + id = task_hub_client.schedule_new_orchestration(orchestrator, input=1) + state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) assert state is not None assert state.name == task.get_name(orchestrator) @@ -146,15 +216,16 @@ def parent_orchestrator(ctx: task.OrchestrationContext, count: int): yield task.when_all(tasks) # Start a worker, which will connect to the sidecar in a background thread - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: w.add_activity(increment) w.add_orchestrator(orchestrator_child) w.add_orchestrator(parent_orchestrator) w.start() + w.wait_for_ready(timeout=10) - task_hub_client = client.TaskHubGrpcClient() - id = task_hub_client.schedule_new_orchestration(parent_orchestrator, input=10) - state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) + with client.TaskHubGrpcClient() as task_hub_client: + id = task_hub_client.schedule_new_orchestration(parent_orchestrator, input=10) + state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) assert state is not None assert state.runtime_status == client.OrchestrationStatus.COMPLETED @@ -170,9 +241,10 @@ def orchestrator(ctx: task.OrchestrationContext, _): return [a, b, c] # Start a worker, which will connect to the sidecar in a background thread - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: w.add_orchestrator(orchestrator) w.start() + w.wait_for_ready(timeout=10) # Start the orchestration and immediately raise events to it. task_hub_client = client.TaskHubGrpcClient() @@ -199,16 +271,17 @@ def orchestrator(ctx: task.OrchestrationContext, _): return "timed out" # Start a worker, which will connect to the sidecar in a background thread - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: w.add_orchestrator(orchestrator) w.start() + w.wait_for_ready(timeout=10) # Start the orchestration and immediately raise events to it. 
- task_hub_client = client.TaskHubGrpcClient() - id = task_hub_client.schedule_new_orchestration(orchestrator) - if raise_event: - task_hub_client.raise_orchestration_event(id, "Approval") - state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) + with client.TaskHubGrpcClient() as task_hub_client: + id = task_hub_client.schedule_new_orchestration(orchestrator) + if raise_event: + task_hub_client.raise_orchestration_event(id, "Approval") + state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) assert state is not None assert state.runtime_status == client.OrchestrationStatus.COMPLETED @@ -224,37 +297,37 @@ def orchestrator(ctx: task.OrchestrationContext, _): return result # Start a worker, which will connect to the sidecar in a background thread - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: w.add_orchestrator(orchestrator) w.start() - - task_hub_client = client.TaskHubGrpcClient() - id = task_hub_client.schedule_new_orchestration(orchestrator) - state = task_hub_client.wait_for_orchestration_start(id, timeout=30) - assert state is not None - - # Suspend the orchestration and wait for it to go into the SUSPENDED state - task_hub_client.suspend_orchestration(id) - while state.runtime_status == client.OrchestrationStatus.RUNNING: - time.sleep(0.1) - state = task_hub_client.get_orchestration_state(id) + w.wait_for_ready(timeout=10) + with client.TaskHubGrpcClient() as task_hub_client: + id = task_hub_client.schedule_new_orchestration(orchestrator) + state = task_hub_client.wait_for_orchestration_start(id, timeout=30) assert state is not None - assert state.runtime_status == client.OrchestrationStatus.SUSPENDED - - # Raise an event to the orchestration and confirm that it does NOT complete - task_hub_client.raise_orchestration_event(id, "my_event", data=42) - try: - state = task_hub_client.wait_for_orchestration_completion(id, timeout=3) - assert False, "Orchestration should not have completed" - except TimeoutError: - pass - # Resume the orchestration and wait for it to complete - task_hub_client.resume_orchestration(id) - state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) - assert state is not None - assert state.runtime_status == client.OrchestrationStatus.COMPLETED - assert state.serialized_output == json.dumps(42) + # Suspend the orchestration and wait for it to go into the SUSPENDED state + task_hub_client.suspend_orchestration(id) + while state.runtime_status == client.OrchestrationStatus.RUNNING: + time.sleep(0.1) + state = task_hub_client.get_orchestration_state(id) + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.SUSPENDED + + # Raise an event to the orchestration and confirm that it does NOT complete + task_hub_client.raise_orchestration_event(id, "my_event", data=42) + try: + state = task_hub_client.wait_for_orchestration_completion(id, timeout=3) + assert False, "Orchestration should not have completed" + except TimeoutError: + pass + + # Resume the orchestration and wait for it to complete + task_hub_client.resume_orchestration(id) + state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_output == json.dumps(42) def test_terminate(): @@ -263,27 +336,27 @@ def orchestrator(ctx: task.OrchestrationContext, _): return result # Start a worker, which will connect to the sidecar in a background thread 
- with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: w.add_orchestrator(orchestrator) w.start() + w.wait_for_ready(timeout=10) + with client.TaskHubGrpcClient() as task_hub_client: + id = task_hub_client.schedule_new_orchestration(orchestrator) + state = task_hub_client.wait_for_orchestration_start(id, timeout=30) + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.RUNNING - task_hub_client = client.TaskHubGrpcClient() - id = task_hub_client.schedule_new_orchestration(orchestrator) - state = task_hub_client.wait_for_orchestration_start(id, timeout=30) - assert state is not None - assert state.runtime_status == client.OrchestrationStatus.RUNNING - - task_hub_client.terminate_orchestration(id, output="some reason for termination") - state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) - assert state is not None - assert state.runtime_status == client.OrchestrationStatus.TERMINATED - assert state.serialized_output == json.dumps("some reason for termination") + task_hub_client.terminate_orchestration(id, output="some reason for termination") + state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.TERMINATED + assert state.serialized_output == json.dumps("some reason for termination") def test_terminate_recursive(): thread_lock = threading.Lock() activity_counter = 0 - delay_time = 4 # seconds + delay_time = 2 # seconds (already optimized from 4s - don't reduce further as it can lead to failures) def increment(ctx, _): with thread_lock: @@ -303,36 +376,39 @@ def parent_orchestrator(ctx: task.OrchestrationContext, count: int): yield task.when_all(tasks) for recurse in [True, False]: - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: w.add_activity(increment) w.add_orchestrator(orchestrator_child) w.add_orchestrator(parent_orchestrator) w.start() + w.wait_for_ready(timeout=10) + with client.TaskHubGrpcClient() as task_hub_client: + instance_id = task_hub_client.schedule_new_orchestration( + parent_orchestrator, input=5 + ) - task_hub_client = client.TaskHubGrpcClient() - instance_id = task_hub_client.schedule_new_orchestration(parent_orchestrator, input=5) - - time.sleep(2) - - output = "Recursive termination = {recurse}" - task_hub_client.terminate_orchestration(instance_id, output=output, recursive=recurse) - - metadata = task_hub_client.wait_for_orchestration_completion(instance_id, timeout=30) - - assert metadata is not None - assert metadata.runtime_status == client.OrchestrationStatus.TERMINATED - assert metadata.serialized_output == f'"{output}"' - - time.sleep(delay_time) + time.sleep(1) # Brief delay to let orchestrations start - if recurse: - assert activity_counter == 0, ( - "Activity should not have executed with recursive termination" + output = f"Recursive termination = {recurse}" + task_hub_client.terminate_orchestration( + instance_id, output=output, recursive=recurse ) - else: - assert activity_counter == 5, ( - "Activity should have executed without recursive termination" + + metadata = task_hub_client.wait_for_orchestration_completion( + instance_id, timeout=30 ) + assert metadata is not None + assert metadata.runtime_status == client.OrchestrationStatus.TERMINATED + assert metadata.serialized_output == f'"{output}"' + time.sleep(delay_time) # Wait for timer to check activity execution + if recurse: + assert activity_counter == 0, ( 
+ "Activity should not have executed with recursive termination" + ) + else: + assert activity_counter == 5, ( + "Activity should have executed without recursive termination" + ) def test_continue_as_new(): @@ -351,9 +427,10 @@ def orchestrator(ctx: task.OrchestrationContext, input: int): return all_results # Start a worker, which will connect to the sidecar in a background thread - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: w.add_orchestrator(orchestrator) w.start() + w.wait_for_ready(timeout=10) task_hub_client = client.TaskHubGrpcClient() id = task_hub_client.schedule_new_orchestration(orchestrator, input=0) @@ -391,7 +468,7 @@ def orchestrator(ctx: task.OrchestrationContext, counter: int): else: return {"counter": counter, "processed": processed, "all_results": activity_results} - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: w.add_activity(double_activity) w.add_orchestrator(orchestrator) w.start() @@ -424,13 +501,13 @@ def test_retry_policies(): child_orch_counter = 0 throw_activity_counter = 0 - # Second setup: With retry policies + # Second setup: With retry policies (minimal delays for faster tests) retry_policy = task.RetryPolicy( - first_retry_interval=timedelta(seconds=1), + first_retry_interval=timedelta(seconds=0.05), # 0.1 → 0.05 (50% faster) max_number_of_attempts=3, backoff_coefficient=1, - max_retry_interval=timedelta(seconds=10), - retry_timeout=timedelta(seconds=30), + max_retry_interval=timedelta(seconds=0.5), # 1 → 0.5 + retry_timeout=timedelta(seconds=2), # 3 → 2 ) def parent_orchestrator_with_retry(ctx: task.OrchestrationContext, _): @@ -449,11 +526,12 @@ def throw_activity_with_retry(ctx: task.ActivityContext, _): throw_activity_counter += 1 raise RuntimeError("Kah-BOOOOM!!!") - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: w.add_orchestrator(parent_orchestrator_with_retry) w.add_orchestrator(child_orchestrator_with_retry) w.add_activity(throw_activity_with_retry) w.start() + w.wait_for_ready(timeout=10) task_hub_client = client.TaskHubGrpcClient() id = task_hub_client.schedule_new_orchestration(parent_orchestrator_with_retry) @@ -468,19 +546,47 @@ def throw_activity_with_retry(ctx: task.ActivityContext, _): assert throw_activity_counter == 9 assert child_orch_counter == 3 + # Test 2: Verify NonRetryableError prevents retries even with retry policy + non_retryable_counter = 0 + + def throw_non_retryable(ctx: task.ActivityContext, _): + nonlocal non_retryable_counter + non_retryable_counter += 1 + raise task.NonRetryableError("Cannot retry this!") + + def orchestrator_with_non_retryable(ctx: task.OrchestrationContext, _): + # Even with retry policy, NonRetryableError should fail immediately + yield ctx.call_activity(throw_non_retryable, retry_policy=retry_policy) + + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: + w.add_orchestrator(orchestrator_with_non_retryable) + w.add_activity(throw_non_retryable) + w.start() + w.wait_for_ready(timeout=10) + + task_hub_client = client.TaskHubGrpcClient() + id = task_hub_client.schedule_new_orchestration(orchestrator_with_non_retryable) + state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.FAILED + assert state.failure_details is not None + assert "Cannot retry this!" 
in state.failure_details.message + # Key assertion: activity was called exactly once (no retries) + assert non_retryable_counter == 1 + def test_retry_timeout(): # This test verifies that the retry timeout is working as expected. - # Max number of attempts is 5 and retry timeout is 14 seconds. - # Total seconds consumed till 4th attempt is 1 + 2 + 4 + 8 = 15 seconds. - # So, the 5th attempt should not be made and the orchestration should fail. + # Max number of attempts is 5 and retry timeout is 13 seconds. + # Delays: 1 + 2 + 4 + 8 = 15 seconds cumulative before the 5th attempt. + # So, the 5th attempt (which would fire at 15s) should not be made and the orchestration should fail. throw_activity_counter = 0 retry_policy = task.RetryPolicy( first_retry_interval=timedelta(seconds=1), max_number_of_attempts=5, backoff_coefficient=2, max_retry_interval=timedelta(seconds=10), - retry_timeout=timedelta(seconds=14), + retry_timeout=timedelta(seconds=13), # Expires before the 5th attempt (due at 15s) ) def mock_orchestrator(ctx: task.OrchestrationContext, _): @@ -491,10 +597,11 @@ def throw_activity(ctx: task.ActivityContext, _): throw_activity_counter += 1 raise RuntimeError("Kah-BOOOOM!!!") - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: w.add_orchestrator(mock_orchestrator) w.add_activity(throw_activity) w.start() + w.wait_for_ready(timeout=10) task_hub_client = client.TaskHubGrpcClient() id = task_hub_client.schedule_new_orchestration(mock_orchestrator) @@ -513,7 +620,7 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): ctx.set_custom_status("foobaz") # Start a worker, which will connect to the sidecar in a background thread - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: w.add_orchestrator(empty_orchestrator) w.start() @@ -529,3 +636,181 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): assert state.serialized_input is None assert state.serialized_output is None assert state.serialized_custom_status == '"foobaz"' + + +def test_now_with_sequence_ordering(): + """ + Test that now_with_sequence() maintains strict ordering across workflow execution. + + This verifies: + 1. Timestamps increment sequentially + 2. Order is preserved across activity calls + 3. 
Deterministic behavior (timestamps are consistent on replay) + """ + + def simple_activity(ctx, input_val: str): + return f"activity_{input_val}_done" + + def timestamp_ordering_workflow(ctx: task.OrchestrationContext, _): + timestamps = [] + + # First timestamp before any activities + t1 = ctx.now_with_sequence() + timestamps.append(("t1_before_activities", t1.isoformat())) + + # Call first activity + result1 = yield ctx.call_activity(simple_activity, input="first") + timestamps.append(("activity_1_result", result1)) + + # Timestamp after first activity + t2 = ctx.now_with_sequence() + timestamps.append(("t2_after_activity_1", t2.isoformat())) + + # Call second activity + result2 = yield ctx.call_activity(simple_activity, input="second") + timestamps.append(("activity_2_result", result2)) + + # Timestamp after second activity + t3 = ctx.now_with_sequence() + timestamps.append(("t3_after_activity_2", t3.isoformat())) + + # A few more rapid timestamps to test counter incrementing + t4 = ctx.now_with_sequence() + timestamps.append(("t4_rapid", t4.isoformat())) + + t5 = ctx.now_with_sequence() + timestamps.append(("t5_rapid", t5.isoformat())) + + # Return all timestamps for verification + return { + "timestamps": timestamps, + "t1": t1.isoformat(), + "t2": t2.isoformat(), + "t3": t3.isoformat(), + "t4": t4.isoformat(), + "t5": t5.isoformat(), + } + + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: + w.add_orchestrator(timestamp_ordering_workflow) + w.add_activity(simple_activity) + w.start() + w.wait_for_ready(timeout=10) + + with client.TaskHubGrpcClient() as c: + instance_id = c.schedule_new_orchestration(timestamp_ordering_workflow) + state = c.wait_for_orchestration_completion( + instance_id, timeout=30, fetch_payloads=True + ) + + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.failure_details is None + + # Parse result + result = json.loads(state.serialized_output) + assert result is not None + + # Verify all timestamps are present + assert "t1" in result + assert "t2" in result + assert "t3" in result + assert "t4" in result + assert "t5" in result + + # Parse timestamps back to datetime objects for comparison + from datetime import datetime + + t1 = datetime.fromisoformat(result["t1"]) + t2 = datetime.fromisoformat(result["t2"]) + t3 = datetime.fromisoformat(result["t3"]) + t4 = datetime.fromisoformat(result["t4"]) + t5 = datetime.fromisoformat(result["t5"]) + + # Verify strict ordering: t1 < t2 < t3 < t4 < t5 + # This is the key guarantee - timestamps must maintain order for tracing + assert t1 < t2, f"t1 ({t1}) should be < t2 ({t2})" + assert t2 < t3, f"t2 ({t2}) should be < t3 ({t3})" + assert t3 < t4, f"t3 ({t3}) should be < t4 ({t4})" + assert t4 < t5, f"t4 ({t4}) should be < t5 ({t5})" + + # Verify that timestamps called in rapid succession (t3, t4, t5 with no activities between) + # have exactly 1 microsecond deltas. These happen within the same replay execution. + delta_t3_t4 = (t4 - t3).total_seconds() * 1_000_000 + delta_t4_t5 = (t5 - t4).total_seconds() * 1_000_000 + + assert delta_t3_t4 == 1.0, f"t3 to t4 should be 1 microsecond, got {delta_t3_t4}" + assert delta_t4_t5 == 1.0, f"t4 to t5 should be 1 microsecond, got {delta_t4_t5}" + + # Note: We don't check exact deltas for t1->t2 or t2->t3 because they span + # activity calls. During replay, current_utc_datetime changes based on event + # timestamps, so the base time shifts. However, ordering is still guaranteed. 
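+        # A minimal sketch (an assumption for illustration, not the SDK's actual
+        # implementation) of the scheme these assertions imply: within a single
+        # replay execution, now_with_sequence() behaves as if it did
+        #
+        #     def now_with_sequence(self):
+        #         ts = self.current_utc_datetime + timedelta(microseconds=self._seq)
+        #         self._seq += 1  # hypothetical per-execution counter
+        #         return ts
+        #
+        # which accounts for the exact 1-microsecond deltas between back-to-back
+        # calls and preserves ordering across activity boundaries as the base
+        # current_utc_datetime advances with the history's event timestamps.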
+ + +def test_cannot_add_orchestrator_while_running(): + """Test that orchestrators cannot be added while the worker is running.""" + def orchestrator(ctx: task.OrchestrationContext, _): + return "done" + + def another_orchestrator(ctx: task.OrchestrationContext, _): + return "another" + + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: + w.add_orchestrator(orchestrator) + w.start() + w.wait_for_ready(timeout=10) + + # Try to add another orchestrator while running + with pytest.raises(RuntimeError, match="Orchestrators cannot be added while the worker is running"): + w.add_orchestrator(another_orchestrator) + + +def test_cannot_add_activity_while_running(): + """Test that activities cannot be added while the worker is running.""" + def activity(ctx: task.ActivityContext, input): + return input + + def another_activity(ctx: task.ActivityContext, input): + return input * 2 + + def orchestrator(ctx: task.OrchestrationContext, _): + return "done" + + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: + w.add_orchestrator(orchestrator) + w.add_activity(activity) + w.start() + w.wait_for_ready(timeout=10) + + # Try to add another activity while running + with pytest.raises(RuntimeError, match="Activities cannot be added while the worker is running"): + w.add_activity(another_activity) + + +def test_can_add_functions_after_stop(): + """Test that orchestrators/activities can be added after stopping the worker.""" + def orchestrator1(ctx: task.OrchestrationContext, _): + return "done" + + def orchestrator2(ctx: task.OrchestrationContext, _): + return "done2" + + def activity(ctx: task.ActivityContext, input): + return input + + w = worker.TaskHubGrpcWorker(stop_timeout=2.0) + w.add_orchestrator(orchestrator1) + w.start() + w.wait_for_ready(timeout=10) + + c = client.TaskHubGrpcClient() + id = c.schedule_new_orchestration(orchestrator1) + state = c.wait_for_orchestration_completion(id, timeout=30) + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + + w.stop() + + # Should be able to add after stop + w.add_orchestrator(orchestrator2) + w.add_activity(activity) diff --git a/tests/durabletask/test_orchestration_e2e_async.py b/tests/durabletask/test_orchestration_e2e_async.py index c441bdc..b71e70b 100644 --- a/tests/durabletask/test_orchestration_e2e_async.py +++ b/tests/durabletask/test_orchestration_e2e_async.py @@ -110,7 +110,7 @@ def orchestrator(ctx: task.OrchestrationContext, input: int): return error_msg # Start a worker, which will connect to the sidecar in a background thread - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: w.add_orchestrator(orchestrator) w.add_activity(throw) w.add_activity(increment_counter) @@ -153,7 +153,7 @@ def parent_orchestrator(ctx: task.OrchestrationContext, count: int): yield task.when_all(tasks) # Start a worker, which will connect to the sidecar in a background thread - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: w.add_activity(increment) w.add_orchestrator(orchestrator_child) w.add_orchestrator(parent_orchestrator) @@ -178,7 +178,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): return [a, b, c] # Start a worker, which will connect to the sidecar in a background thread - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: w.add_orchestrator(orchestrator) w.start() @@ -208,7 +208,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): return "timed out" # Start 
a worker, which will connect to the sidecar in a background thread - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: w.add_orchestrator(orchestrator) w.start() @@ -234,7 +234,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): return result # Start a worker, which will connect to the sidecar in a background thread - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: w.add_orchestrator(orchestrator) w.start() # there could be a race condition if the workflow is scheduled before orchestrator is started @@ -275,7 +275,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): return result # Start a worker, which will connect to the sidecar in a background thread - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: w.add_orchestrator(orchestrator) w.start() @@ -302,7 +302,7 @@ def child(ctx: task.OrchestrationContext, _): return result # Start a worker, which will connect to the sidecar in a background thread - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: w.add_orchestrator(root) w.add_orchestrator(child) w.start() @@ -345,7 +345,7 @@ def orchestrator(ctx: task.OrchestrationContext, input: int): return all_results # Start a worker, which will connect to the sidecar in a background thread - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: w.add_orchestrator(orchestrator) w.start() @@ -376,13 +376,13 @@ async def test_retry_policies(): child_orch_counter = 0 throw_activity_counter = 0 - # Second setup: With retry policies + # Second setup: With retry policies (minimal delays for faster tests) retry_policy = task.RetryPolicy( - first_retry_interval=timedelta(seconds=1), + first_retry_interval=timedelta(seconds=0.05), # 1 → 0.05 for faster tests max_number_of_attempts=3, backoff_coefficient=1, - max_retry_interval=timedelta(seconds=10), - retry_timeout=timedelta(seconds=30), + max_retry_interval=timedelta(seconds=0.5), # 10 → 0.5 + retry_timeout=timedelta(seconds=2), # 30 → 2 ) def parent_orchestrator_with_retry(ctx: task.OrchestrationContext, _): @@ -401,7 +401,7 @@ def throw_activity_with_retry(ctx: task.ActivityContext, _): throw_activity_counter += 1 raise RuntimeError("Kah-BOOOOM!!!") - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: w.add_orchestrator(parent_orchestrator_with_retry) w.add_orchestrator(child_orchestrator_with_retry) w.add_activity(throw_activity_with_retry) @@ -423,16 +423,16 @@ async def test_retry_timeout(): # This test verifies that the retry timeout is working as expected. - # Max number of attempts is 5 and retry timeout is 14 seconds. - # Total seconds consumed till 4th attempt is 1 + 2 + 4 + 8 = 15 seconds. - # So, the 5th attempt should not be made and the orchestration should fail. + # Max number of attempts is 5 and retry timeout is 13 seconds. + # Delays: 1 + 2 + 4 + 8 = 15 seconds cumulative before the 5th attempt. + # So, the 5th attempt (which would fire at 15s) should not be made and the orchestration should fail. 
throw_activity_counter = 0 retry_policy = task.RetryPolicy( first_retry_interval=timedelta(seconds=1), max_number_of_attempts=5, backoff_coefficient=2, max_retry_interval=timedelta(seconds=10), - retry_timeout=timedelta(seconds=14), + retry_timeout=timedelta(seconds=13), # Expires before the 5th attempt (due at 15s) ) def mock_orchestrator(ctx: task.OrchestrationContext, _): @@ -443,7 +443,7 @@ def throw_activity(ctx: task.ActivityContext, _): throw_activity_counter += 1 raise RuntimeError("Kah-BOOOOM!!!") - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: w.add_orchestrator(mock_orchestrator) w.add_activity(throw_activity) w.start() @@ -465,7 +465,7 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): ctx.set_custom_status("foobaz") # Start a worker, which will connect to the sidecar in a background thread - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: w.add_orchestrator(empty_orchestrator) w.start() diff --git a/tests/durabletask/test_orchestration_executor.py b/tests/durabletask/test_orchestration_executor.py index 964512f..bf81f26 100644 --- a/tests/durabletask/test_orchestration_executor.py +++ b/tests/durabletask/test_orchestration_executor.py @@ -3,7 +3,7 @@ import json import logging -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone import pytest @@ -826,7 +826,7 @@ def test_nondeterminism_expected_sub_orchestration_task_completion_wrong_task_ty def orchestrator(ctx: task.OrchestrationContext, _): result = yield ctx.create_timer( - datetime.utcnow() + datetime.now(timezone.utc) ) # created timer but history expects sub-orchestration return result @@ -920,7 +920,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): # Complete the timer task. The orchestration should move to the wait_for_external_event step, which # should then complete immediately because the event was buffered in the old event history. - timer_due_time = datetime.utcnow() + timedelta(days=1) + timer_due_time = datetime.now(timezone.utc) + timedelta(days=1) old_events = new_events + [helpers.new_timer_created_event(1, timer_due_time)] new_events = [helpers.new_timer_fired_event(1, timer_due_time)] executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) @@ -1013,9 +1013,9 @@ def orchestrator(ctx: task.OrchestrationContext, input: int): helpers.new_event_raised_event("my_event", encoded_input="42"), helpers.new_event_raised_event("my_event", encoded_input="43"), helpers.new_event_raised_event("my_event", encoded_input="44"), - helpers.new_timer_created_event(1, datetime.utcnow() + timedelta(days=1)), + helpers.new_timer_created_event(1, datetime.now(timezone.utc) + timedelta(days=1)), ] - new_events = [helpers.new_timer_fired_event(1, datetime.utcnow() + timedelta(days=1))] + new_events = [helpers.new_timer_fired_event(1, datetime.now(timezone.utc) + timedelta(days=1))] # Execute the orchestration. 
It should be in a running state waiting for the timer to fire executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) @@ -1447,6 +1447,261 @@ def orchestrator(ctx: task.OrchestrationContext, _): assert str(ex) in complete_action.failureDetails.errorMessage + +def test_activity_non_retryable_default_exception(): + """If activity fails with NonRetryableError, it should not be retried and the orchestration should fail immediately.""" + + def dummy_activity(ctx, _): + raise task.NonRetryableError("boom") + + def orchestrator(ctx: task.OrchestrationContext, _): + yield ctx.call_activity( + dummy_activity, + retry_policy=task.RetryPolicy( + first_retry_interval=timedelta(seconds=1), + max_number_of_attempts=3, + backoff_coefficient=1, + ), + ) + + registry = worker._Registry() + name = registry.add_orchestrator(orchestrator) + + current_timestamp = datetime.now(timezone.utc) + old_events = [ + helpers.new_orchestrator_started_event(timestamp=current_timestamp), + helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None), + helpers.new_task_scheduled_event(1, task.get_name(dummy_activity)), + ] + new_events = [ + helpers.new_orchestrator_started_event(timestamp=current_timestamp), + helpers.new_task_failed_event(1, task.NonRetryableError("boom")), + ] + + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) + actions = result.actions + complete_action = get_and_validate_single_complete_orchestration_action(actions) + assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_FAILED + assert "Activity task #1 failed: boom" in complete_action.failureDetails.errorMessage + + +def test_activity_non_retryable_policy_name(): + """If policy marks ValueError as non-retryable (by name), fail immediately without retry.""" + + def dummy_activity(ctx, _): + raise ValueError("boom") + + def orchestrator(ctx: task.OrchestrationContext, _): + yield ctx.call_activity( + dummy_activity, + retry_policy=task.RetryPolicy( + first_retry_interval=timedelta(seconds=1), + max_number_of_attempts=5, + non_retryable_error_types=["ValueError"], + ), + ) + + registry = worker._Registry() + name = registry.add_orchestrator(orchestrator) + + current_timestamp = datetime.now(timezone.utc) + old_events = [ + helpers.new_orchestrator_started_event(timestamp=current_timestamp), + helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None), + helpers.new_task_scheduled_event(1, task.get_name(dummy_activity)), + ] + new_events = [ + helpers.new_orchestrator_started_event(timestamp=current_timestamp), + helpers.new_task_failed_event(1, ValueError("boom")), + ] + + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) + actions = result.actions + complete_action = get_and_validate_single_complete_orchestration_action(actions) + assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_FAILED + assert "Activity task #1 failed: boom" in complete_action.failureDetails.errorMessage + + +def test_activity_generic_exception_is_retryable(): + """Verify that generic Exception is retryable by default (not treated as non-retryable).""" + + def dummy_activity(ctx, _): + raise Exception("generic error") + + def orchestrator(ctx: task.OrchestrationContext, _): + yield ctx.call_activity( + dummy_activity, + retry_policy=task.RetryPolicy( + first_retry_interval=timedelta(seconds=1), + max_number_of_attempts=3, + 
backoff_coefficient=1, + ), + ) + + registry = worker._Registry() + name = registry.add_orchestrator(orchestrator) + + current_timestamp = datetime.now(timezone.utc) + # First attempt fails + old_events = [ + helpers.new_orchestrator_started_event(timestamp=current_timestamp), + helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None), + helpers.new_task_scheduled_event(1, task.get_name(dummy_activity)), + ] + new_events = [ + helpers.new_orchestrator_started_event(timestamp=current_timestamp), + helpers.new_task_failed_event(1, Exception("generic error")), + ] + + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) + actions = result.actions + # Should schedule a retry timer, not fail immediately + assert len(actions) == 1 + assert actions[0].HasField("createTimer") + assert actions[0].id == 2 + + # Simulate the timer firing and activity being rescheduled + expected_fire_at = current_timestamp + timedelta(seconds=1) + old_events = old_events + new_events + current_timestamp = expected_fire_at + new_events = [ + helpers.new_orchestrator_started_event(current_timestamp), + helpers.new_timer_fired_event(2, current_timestamp), + ] + result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) + actions = result.actions + assert len(actions) == 2 # timer + rescheduled task + assert actions[1].HasField("scheduleTask") + assert actions[1].id == 1 + + # Second attempt also fails + old_events = old_events + new_events + new_events = [ + helpers.new_orchestrator_started_event(current_timestamp), + helpers.new_task_failed_event(1, Exception("generic error")), + ] + + result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) + actions = result.actions + # Should schedule another retry timer + assert len(actions) == 3 + assert actions[2].HasField("createTimer") + assert actions[2].id == 3 + + # Simulate the timer firing and activity being rescheduled + expected_fire_at = current_timestamp + timedelta(seconds=1) + old_events = old_events + new_events + current_timestamp = expected_fire_at + new_events = [ + helpers.new_orchestrator_started_event(current_timestamp), + helpers.new_timer_fired_event(3, current_timestamp), + ] + result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) + actions = result.actions + assert len(actions) == 3 # timer + rescheduled task + assert actions[1].HasField("scheduleTask") + assert actions[1].id == 1 + + # Third attempt fails - should exhaust retries + old_events = old_events + new_events + new_events = [ + helpers.new_orchestrator_started_event(current_timestamp), + helpers.new_task_failed_event(1, Exception("generic error")), + ] + + result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) + actions = result.actions + # Now should fail - no more retries + complete_action = get_and_validate_single_complete_orchestration_action(actions) + assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_FAILED + assert ( + "Activity task #1 failed: generic error" in complete_action.failureDetails.errorMessage + ) + + +def test_sub_orchestration_non_retryable_default_exception(): + """If sub-orchestrator fails with NonRetryableError, do not retry and fail immediately.""" + + def child(ctx: task.OrchestrationContext, _): + pass + + def parent(ctx: task.OrchestrationContext, _): + yield ctx.call_sub_orchestrator( + child, + retry_policy=task.RetryPolicy( + first_retry_interval=timedelta(seconds=1), + max_number_of_attempts=3, + 
), + ) + + registry = worker._Registry() + child_name = registry.add_orchestrator(child) + parent_name = registry.add_orchestrator(parent) + + current_timestamp = datetime.now(timezone.utc) + old_events = [ + helpers.new_orchestrator_started_event(timestamp=current_timestamp), + helpers.new_execution_started_event(parent_name, TEST_INSTANCE_ID, encoded_input=None), + helpers.new_sub_orchestration_created_event(1, child_name, "sub-1", encoded_input=None), + ] + new_events = [ + helpers.new_orchestrator_started_event(timestamp=current_timestamp), + helpers.new_sub_orchestration_failed_event(1, task.NonRetryableError("boom")), + ] + + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) + actions = result.actions + complete_action = get_and_validate_single_complete_orchestration_action(actions) + assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_FAILED + assert ( + "Sub-orchestration task #1 failed: boom" in complete_action.failureDetails.errorMessage + ) + + +def test_sub_orchestration_non_retryable_policy_type(): + """If policy marks ValueError as non-retryable (by class), fail immediately without retry.""" + + def child(ctx: task.OrchestrationContext, _): + pass + + def parent(ctx: task.OrchestrationContext, _): + yield ctx.call_sub_orchestrator( + child, + retry_policy=task.RetryPolicy( + first_retry_interval=timedelta(seconds=1), + max_number_of_attempts=5, + non_retryable_error_types=[ValueError], + ), + ) + + registry = worker._Registry() + child_name = registry.add_orchestrator(child) + parent_name = registry.add_orchestrator(parent) + + current_timestamp = datetime.now(timezone.utc) + old_events = [ + helpers.new_orchestrator_started_event(timestamp=current_timestamp), + helpers.new_execution_started_event(parent_name, TEST_INSTANCE_ID, encoded_input=None), + helpers.new_sub_orchestration_created_event(1, child_name, "sub-1", encoded_input=None), + ] + new_events = [ + helpers.new_orchestrator_started_event(timestamp=current_timestamp), + helpers.new_sub_orchestration_failed_event(1, ValueError("boom")), + ] + + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) + actions = result.actions + complete_action = get_and_validate_single_complete_orchestration_action(actions) + assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_FAILED + assert ( + "Sub-orchestration task #1 failed: boom" in complete_action.failureDetails.errorMessage + ) + + def get_and_validate_single_complete_orchestration_action( actions: list[pb.OrchestratorAction], ) -> pb.CompleteOrchestrationAction: diff --git a/tests/durabletask/test_registry.py b/tests/durabletask/test_registry.py new file mode 100644 index 0000000..150a870 --- /dev/null +++ b/tests/durabletask/test_registry.py @@ -0,0 +1,155 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +"""Unit tests for the _Registry class validation logic.""" + +import pytest + +from durabletask import worker + + +def test_registry_add_orchestrator_none(): + """Test that adding a None orchestrator raises ValueError.""" + registry = worker._Registry() + + with pytest.raises(ValueError, match="An orchestrator function argument is required"): + registry.add_orchestrator(None) + + +def test_registry_add_named_orchestrator_empty_name(): + """Test that adding an orchestrator with empty name raises ValueError.""" + registry = worker._Registry() + + def dummy_orchestrator(ctx, input): + return "done" + + with pytest.raises(ValueError, match="A non-empty orchestrator name is required"): + registry.add_named_orchestrator("", dummy_orchestrator) + + +def test_registry_add_orchestrator_duplicate(): + """Test that adding a duplicate orchestrator raises ValueError.""" + registry = worker._Registry() + + def dummy_orchestrator(ctx, input): + return "done" + + name = "test_orchestrator" + registry.add_named_orchestrator(name, dummy_orchestrator) + + with pytest.raises(ValueError, match=f"A '{name}' orchestrator already exists"): + registry.add_named_orchestrator(name, dummy_orchestrator) + + +def test_registry_add_activity_none(): + """Test that adding a None activity raises ValueError.""" + registry = worker._Registry() + + with pytest.raises(ValueError, match="An activity function argument is required"): + registry.add_activity(None) + + +def test_registry_add_named_activity_empty_name(): + """Test that adding an activity with empty name raises ValueError.""" + registry = worker._Registry() + + def dummy_activity(ctx, input): + return "done" + + with pytest.raises(ValueError, match="A non-empty activity name is required"): + registry.add_named_activity("", dummy_activity) + + +def test_registry_add_activity_duplicate(): + """Test that adding a duplicate activity raises ValueError.""" + registry = worker._Registry() + + def dummy_activity(ctx, input): + return "done" + + name = "test_activity" + registry.add_named_activity(name, dummy_activity) + + with pytest.raises(ValueError, match=f"A '{name}' activity already exists"): + registry.add_named_activity(name, dummy_activity) + + +def test_registry_get_orchestrator_exists(): + """Test retrieving an existing orchestrator.""" + registry = worker._Registry() + + def dummy_orchestrator(ctx, input): + return "done" + + name = registry.add_orchestrator(dummy_orchestrator) + retrieved = registry.get_orchestrator(name) + + assert retrieved is dummy_orchestrator + + +def test_registry_get_orchestrator_not_exists(): + """Test retrieving a non-existent orchestrator returns None.""" + registry = worker._Registry() + + retrieved = registry.get_orchestrator("non_existent") + + assert retrieved is None + + +def test_registry_get_activity_exists(): + """Test retrieving an existing activity.""" + registry = worker._Registry() + + def dummy_activity(ctx, input): + return "done" + + name = registry.add_activity(dummy_activity) + retrieved = registry.get_activity(name) + + assert retrieved is dummy_activity + + +def test_registry_get_activity_not_exists(): + """Test retrieving a non-existent activity returns None.""" + registry = worker._Registry() + + retrieved = registry.get_activity("non_existent") + + assert retrieved is None + + +def test_registry_add_multiple_orchestrators(): + """Test adding multiple different orchestrators.""" + registry = worker._Registry() + + def orchestrator1(ctx, input): + return "one" + + def orchestrator2(ctx, input): + return "two" + + 
name1 = registry.add_orchestrator(orchestrator1) + name2 = registry.add_orchestrator(orchestrator2) + + assert name1 != name2 + assert registry.get_orchestrator(name1) is orchestrator1 + assert registry.get_orchestrator(name2) is orchestrator2 + + +def test_registry_add_multiple_activities(): + """Test adding multiple different activities.""" + registry = worker._Registry() + + def activity1(ctx, input): + return "one" + + def activity2(ctx, input): + return "two" + + name1 = registry.add_activity(activity1) + name2 = registry.add_activity(activity2) + + assert name1 != name2 + assert registry.get_activity(name1) is activity1 + assert registry.get_activity(name2) is activity2 + diff --git a/tests/durabletask/test_worker_grpc_errors.py b/tests/durabletask/test_worker_grpc_errors.py new file mode 100644 index 0000000..52a334c --- /dev/null +++ b/tests/durabletask/test_worker_grpc_errors.py @@ -0,0 +1,114 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from unittest.mock import MagicMock, Mock, patch + +import grpc + +from durabletask import worker + + +def test_execute_orchestrator_grpc_error_benign_cancelled(): + """Test that benign gRPC errors in orchestrator execution are handled gracefully.""" + w = worker.TaskHubGrpcWorker() + + # Add a dummy orchestrator + def test_orchestrator(ctx, input): + return "result" + + w.add_orchestrator(test_orchestrator) + + # Mock the stub to raise a benign error + mock_stub = MagicMock() + mock_error = grpc.RpcError() + mock_error.code = Mock(return_value=grpc.StatusCode.CANCELLED) + mock_stub.CompleteOrchestratorTask.side_effect = mock_error + + # Create a mock request with proper structure + mock_req = MagicMock() + mock_req.instanceId = "test-id" + mock_req.pastEvents = [] + mock_req.newEvents = [MagicMock()] + mock_req.newEvents[0].HasField = lambda x: x == "executionStarted" + mock_req.newEvents[0].executionStarted.name = "test_orchestrator" + mock_req.newEvents[0].executionStarted.input = None + mock_req.newEvents[0].router.targetAppID = None + mock_req.newEvents[0].router.sourceAppID = None + mock_req.newEvents[0].timestamp.ToDatetime = Mock(return_value=None) + + # Should not raise exception (benign error) + w._execute_orchestrator(mock_req, mock_stub, "token") + + +def test_execute_orchestrator_grpc_error_non_benign(): + """Test that non-benign gRPC errors in orchestrator execution are logged.""" + w = worker.TaskHubGrpcWorker() + + # Add a dummy orchestrator + def test_orchestrator(ctx, input): + return "result" + + w.add_orchestrator(test_orchestrator) + + # Mock the stub to raise a non-benign error + mock_stub = MagicMock() + mock_error = grpc.RpcError() + mock_error.code = Mock(return_value=grpc.StatusCode.INTERNAL) + mock_stub.CompleteOrchestratorTask.side_effect = mock_error + + # Create a mock request with proper structure + mock_req = MagicMock() + mock_req.instanceId = "test-id" + mock_req.pastEvents = [] + mock_req.newEvents = [MagicMock()] + mock_req.newEvents[0].HasField = lambda x: x == "executionStarted" + 
mock_req.newEvents[0].executionStarted.name = "test_orchestrator" + mock_req.newEvents[0].executionStarted.input = None + mock_req.newEvents[0].router.targetAppID = None + mock_req.newEvents[0].router.sourceAppID = None + mock_req.newEvents[0].timestamp.ToDatetime = Mock(return_value=None) + + # Should not raise exception (error is logged but handled) + with patch.object(w._logger, "exception") as mock_log: + w._execute_orchestrator(mock_req, mock_stub, "token") + # Verify error was logged + assert mock_log.called + + +def test_execute_activity_grpc_error_benign(): + """Test that benign gRPC errors in activity execution are handled gracefully.""" + w = worker.TaskHubGrpcWorker() + + # Add a dummy activity + def test_activity(ctx, input): + return "result" + + w.add_activity(test_activity) + + # Mock the stub to raise a benign error + mock_stub = MagicMock() + mock_error = grpc.RpcError() + mock_error.code = Mock(return_value=grpc.StatusCode.CANCELLED) + str_return = "unknown instance ID/task ID combo" + mock_error.__str__ = Mock(return_value=str_return) + mock_stub.CompleteActivityTask.side_effect = mock_error + + # Create a mock request + mock_req = MagicMock() + mock_req.orchestrationInstance.instanceId = "test-id" + mock_req.name = "test_activity" + mock_req.taskId = 1 + mock_req.input.value = '""' + + # Should not raise exception (benign error) + w._execute_activity(mock_req, mock_stub, "token") diff --git a/tox.ini b/tox.ini index 9b21313..b405432 100644 --- a/tox.ini +++ b/tox.ini @@ -10,17 +10,19 @@ runner = virtualenv [testenv] # you can run tox with the e2e pytest marker using tox factors: -# tox -e py310,py311,py312,py313,py314 -- e2e -# or single one with: +# # start dapr sidecar (better than durabletask-go for multi-app executions) +# dapr init # maybe not needed if already done +# dapr run --app-id test-app --dapr-grpc-port 4001 --resources-path ./examples/components/ +# # In a separate terminal, run e2e tests (appends to .coverage) # tox -e py310-e2e -# to use custom grpc endpoint and not capture print statements (-s arg in pytest): -# DAPR_GRPC_ENDPOINT=localhost:12345 tox -e py310-e2e -- -s +# to use custom grpc endpoint: +# DAPR_GRPC_ENDPOINT=localhost:12345 tox -e py310-e2e setenv = PYTHONDONTWRITEBYTECODE=1 deps = .[dev] commands = !e2e: pytest -m "not e2e" --verbose - e2e: pytest -m e2e --verbose + e2e: pytest -m e2e --verbose commands_pre = pip3 install -e {toxinidir}/ allowlist_externals = pip3
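Taken together, the test hunks above exercise a small worker/client lifecycle surface: a `stop_timeout` argument on the worker constructor, `wait_for_ready()` before scheduling work, `wait_for_idle()` after completion, and context-manager support on both the worker and the client. A minimal end-to-end usage sketch, assuming only the names and signatures the tests exercise (and a sidecar listening on the default endpoint):

```python
from durabletask import client, task, worker

def plus_one(_: task.ActivityContext, n: int) -> int:
    return n + 1

def sequence(ctx: task.OrchestrationContext, start: int):
    result = yield ctx.call_activity(plus_one, input=start)
    return result

# stop_timeout bounds how long worker shutdown waits for in-flight work.
with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w:
    w.add_orchestrator(sequence)
    w.add_activity(plus_one)
    w.start()
    w.wait_for_ready(timeout=10)  # block until the worker is connected

    # The client is a context manager; exiting closes its gRPC channel.
    with client.TaskHubGrpcClient() as c:
        instance_id = c.schedule_new_orchestration(sequence, input=1)
        state = c.wait_for_orchestration_completion(instance_id, timeout=30)
        assert state is not None
        assert state.runtime_status == client.OrchestrationStatus.COMPLETED

    assert w.wait_for_idle(timeout=5.0)  # drain any lingering activity work
```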