diff --git a/.github/workflows/typecheck.yml b/.github/workflows/typecheck.yml new file mode 100644 index 0000000..f10463a --- /dev/null +++ b/.github/workflows/typecheck.yml @@ -0,0 +1,38 @@ +name: Type Check (pyright) + +on: + push: + branches: + - "main" + tags: + - "v*" + - "azuremanaged-v*" + pull_request: + branches: + - "main" + +permissions: + contents: read + +jobs: + pyright: + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Python 3.10 (lowest supported) + uses: actions/setup-python@v5 + with: + python-version: "3.10" + + - name: Install packages and dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install -e ".[azure-blob-payloads,opentelemetry]" + pip install -e ./durabletask-azuremanaged + pip install pyright + + - name: Run pyright (strict, Python 3.10) + run: pyright diff --git a/CHANGELOG.md b/CHANGELOG.md index 7a7daf3..82bd8cf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,17 @@ ADDED resiliency-aware teardown introduced in v1.5.0 (in-flight recreate thread join, retired-channel timer cancellation, and SDK-owned channel cleanup) runs unchanged through the new `with` path. +- Added a pyright type-check CI workflow that runs on pull requests and pushes + to `main`, using strict mode against the lowest supported Python version + (3.10) across both `durabletask` and `durabletask-azuremanaged` packages. +- Improved type coverage across the public API. `OrchestrationContext.create_timer` + now returns the specific `TimerTask` type (previously `CancellableTask`) + ([#93](https://github.com/microsoft/durabletask-python/issues/93)), and + `WhenAnyTask` is now generic with `when_any(tasks: Sequence[Task[T]]) -> WhenAnyTask[T]` + for better static type inference of the completing child task + ([#94](https://github.com/microsoft/durabletask-python/issues/94)). + These changes also broadly improve generic type-safety hints throughout the + SDK ([#92](https://github.com/microsoft/durabletask-python/issues/92)). ## v1.5.0 diff --git a/dev-requirements.txt b/dev-requirements.txt index 98f4c30..6ab2b92 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,2 +1,3 @@ grpcio-tools pymarkdownlnt +pyright diff --git a/durabletask-azuremanaged/CHANGELOG.md b/durabletask-azuremanaged/CHANGELOG.md index 5dbe203..3745c56 100644 --- a/durabletask-azuremanaged/CHANGELOG.md +++ b/durabletask-azuremanaged/CHANGELOG.md @@ -7,6 +7,12 @@ adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## Unreleased +- Improved type coverage benefits Azure Managed users: `create_timer` now + returns the specific `TimerTask` type and `when_any` is generic so the + completing child task is type-checked through `DurableTaskSchedulerClient`, + `AsyncDurableTaskSchedulerClient`, and `DurableTaskSchedulerWorker` derived + orchestrations. + ## v1.5.0 - Updates base dependency to durabletask v1.5.0 diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/internal/durabletask_grpc_interceptor.py b/durabletask-azuremanaged/durabletask/azuremanaged/internal/durabletask_grpc_interceptor.py index ce6f3ee..e7cbb10 100644 --- a/durabletask-azuremanaged/durabletask/azuremanaged/internal/durabletask_grpc_interceptor.py +++ b/durabletask-azuremanaged/durabletask/azuremanaged/internal/durabletask_grpc_interceptor.py @@ -14,8 +14,6 @@ from durabletask.internal.grpc_interceptor import ( DefaultAsyncClientInterceptorImpl, DefaultClientInterceptorImpl, - _AsyncClientCallDetails, - _ClientCallDetails, ) @@ -62,7 +60,7 @@ def _upsert_authorization_header(self, token: str) -> None: self._metadata.append(("authorization", f"Bearer {token}")) def _intercept_call( - self, client_call_details: _ClientCallDetails) -> grpc.ClientCallDetails: + self, client_call_details: grpc.ClientCallDetails) -> grpc.ClientCallDetails: """Internal intercept_call implementation which adds metadata to grpc metadata in the RPC call details.""" # Refresh the auth token if a credential was provided. The call to @@ -114,7 +112,7 @@ def _upsert_authorization_header(self, token: str) -> None: self._metadata.append(("authorization", f"Bearer {token}")) async def _intercept_call( - self, client_call_details: _AsyncClientCallDetails) -> grpc.aio.ClientCallDetails: + self, client_call_details: grpc.aio.ClientCallDetails) -> grpc.aio.ClientCallDetails: """Internal intercept_call implementation which adds metadata to grpc metadata in the RPC call details.""" # Refresh the auth token if a credential was provided. The call to diff --git a/durabletask/client.py b/durabletask/client.py index 7c85f9f..47e711c 100644 --- a/durabletask/client.py +++ b/durabletask/client.py @@ -6,13 +6,15 @@ import threading import time import uuid +from collections.abc import AsyncIterable, Iterable, Sequence from dataclasses import dataclass from datetime import datetime from enum import Enum -from typing import Any, Generic, Sequence, TypeVar +from typing import Any, Generic, Protocol, TypeVar, cast import grpc import grpc.aio +from google.protobuf import wrappers_pb2 import durabletask.history as history from durabletask.entities import EntityInstanceId @@ -64,8 +66,8 @@ class OrchestrationStatus(Enum): PENDING = pb.ORCHESTRATION_STATUS_PENDING SUSPENDED = pb.ORCHESTRATION_STATUS_SUSPENDED - def __str__(self): - return helpers.get_orchestration_status_str(self.value) + def __str__(self) -> str: + return cast(str, helpers.get_orchestration_status_str(self.value)) @dataclass @@ -173,6 +175,128 @@ def parse_orchestration_state(state: pb.OrchestrationState) -> OrchestrationStat _RETIRED_CHANNEL_CLOSE_DELAY_SECONDS = 30.0 +class _SyncTaskHubSidecarServiceStub(Protocol): + def StartInstance(self, request: pb.CreateInstanceRequest) -> pb.CreateInstanceResponse: + ... + + def GetInstance(self, request: pb.GetInstanceRequest) -> pb.GetInstanceResponse: + ... + + def StreamInstanceHistory(self, request: pb.StreamInstanceHistoryRequest) -> Iterable[pb.HistoryChunk]: + ... + + def ListInstanceIds(self, request: pb.ListInstanceIdsRequest) -> pb.ListInstanceIdsResponse: + ... + + def QueryInstances(self, request: pb.QueryInstancesRequest) -> pb.QueryInstancesResponse: + ... + + def WaitForInstanceStart( + self, + request: pb.GetInstanceRequest, + *, + timeout: float | None = None) -> pb.GetInstanceResponse: + ... + + def WaitForInstanceCompletion( + self, + request: pb.GetInstanceRequest, + *, + timeout: float | None = None) -> pb.GetInstanceResponse: + ... + + def RaiseEvent(self, request: pb.RaiseEventRequest) -> pb.RaiseEventResponse: + ... + + def TerminateInstance(self, request: pb.TerminateRequest) -> pb.TerminateResponse: + ... + + def SuspendInstance(self, request: pb.SuspendRequest) -> pb.SuspendResponse: + ... + + def ResumeInstance(self, request: pb.ResumeRequest) -> pb.ResumeResponse: + ... + + def RestartInstance(self, request: pb.RestartInstanceRequest) -> pb.RestartInstanceResponse: + ... + + def PurgeInstances(self, request: pb.PurgeInstancesRequest) -> pb.PurgeInstancesResponse: + ... + + def SignalEntity(self, request: pb.SignalEntityRequest) -> pb.SignalEntityResponse: + ... + + def GetEntity(self, request: pb.GetEntityRequest) -> pb.GetEntityResponse: + ... + + def QueryEntities(self, request: pb.QueryEntitiesRequest) -> pb.QueryEntitiesResponse: + ... + + def CleanEntityStorage(self, request: pb.CleanEntityStorageRequest) -> pb.CleanEntityStorageResponse: + ... + + +class _AsyncTaskHubSidecarServiceStub(Protocol): + async def StartInstance(self, request: pb.CreateInstanceRequest) -> pb.CreateInstanceResponse: + ... + + async def GetInstance(self, request: pb.GetInstanceRequest) -> pb.GetInstanceResponse: + ... + + def StreamInstanceHistory(self, request: pb.StreamInstanceHistoryRequest) -> AsyncIterable[pb.HistoryChunk]: + ... + + async def ListInstanceIds(self, request: pb.ListInstanceIdsRequest) -> pb.ListInstanceIdsResponse: + ... + + async def QueryInstances(self, request: pb.QueryInstancesRequest) -> pb.QueryInstancesResponse: + ... + + async def WaitForInstanceStart( + self, + request: pb.GetInstanceRequest, + *, + timeout: float | None = None) -> pb.GetInstanceResponse: + ... + + async def WaitForInstanceCompletion( + self, + request: pb.GetInstanceRequest, + *, + timeout: float | None = None) -> pb.GetInstanceResponse: + ... + + async def RaiseEvent(self, request: pb.RaiseEventRequest) -> pb.RaiseEventResponse: + ... + + async def TerminateInstance(self, request: pb.TerminateRequest) -> pb.TerminateResponse: + ... + + async def SuspendInstance(self, request: pb.SuspendRequest) -> pb.SuspendResponse: + ... + + async def ResumeInstance(self, request: pb.ResumeRequest) -> pb.ResumeResponse: + ... + + async def RestartInstance(self, request: pb.RestartInstanceRequest) -> pb.RestartInstanceResponse: + ... + + async def PurgeInstances(self, request: pb.PurgeInstancesRequest) -> pb.PurgeInstancesResponse: + ... + + async def SignalEntity(self, request: pb.SignalEntityRequest) -> pb.SignalEntityResponse: + ... + + async def GetEntity(self, request: pb.GetEntityRequest) -> pb.GetEntityResponse: + ... + + async def QueryEntities(self, request: pb.QueryEntitiesRequest) -> pb.QueryEntitiesResponse: + ... + + async def CleanEntityStorage(self, request: pb.CleanEntityStorageRequest) -> pb.CleanEntityStorageResponse: + ... + + class TaskHubGrpcClient: def __init__(self, *, host_address: str | None = None, @@ -245,7 +369,7 @@ def __init__(self, *, # observable effect. Callers wanting resiliency on a custom channel # can prepend the interceptor themselves via grpc.intercept_channel. self._channel = channel - self._stub = stubs.TaskHubSidecarServiceStub(channel) + self._stub = cast(_SyncTaskHubSidecarServiceStub, stubs.TaskHubSidecarServiceStub(channel)) self._logger = shared.get_logger("client", log_handler, log_formatter) self.default_version = default_version self._payload_store = payload_store @@ -322,7 +446,7 @@ def _maybe_recreate_channel(self) -> None: interceptors=self._interceptors, channel_options=self._channel_options, ) - self._stub = stubs.TaskHubSidecarServiceStub(self._channel) + self._stub = cast(_SyncTaskHubSidecarServiceStub, stubs.TaskHubSidecarServiceStub(self._channel)) self._last_recreate_time = now self._client_failure_tracker.record_success() close_timer = threading.Timer( @@ -459,11 +583,11 @@ def get_all_orchestration_states(self, ) -> list[OrchestrationState]: if orchestration_query is None: orchestration_query = OrchestrationQuery() - _continuation_token = None + _continuation_token: wrappers_pb2.StringValue | None = None self._logger.info(f"Querying orchestration instances with query: {orchestration_query}") - states = [] + states: list[OrchestrationState] = [] while True: req = build_query_instances_req(orchestration_query, _continuation_token) @@ -621,11 +745,11 @@ def get_all_entities(self, entity_query: EntityQuery | None = None) -> list[EntityMetadata]: if entity_query is None: entity_query = EntityQuery() - _continuation_token = None + _continuation_token: wrappers_pb2.StringValue | None = None self._logger.info(f"Retrieving entities by filter: {entity_query}") - entities = [] + entities: list[EntityMetadata] = [] while True: query_request = build_query_entities_req(entity_query, _continuation_token) @@ -647,7 +771,7 @@ def clean_entity_storage(self, empty_entities_removed = 0 orphaned_locks_released = 0 - _continuation_token = None + _continuation_token: wrappers_pb2.StringValue | None = None while True: req = pb.CleanEntityStorageRequest( @@ -741,7 +865,7 @@ def __init__(self, *, # leave the failure-tracking opt-out implicit: callers wanting full # resiliency should let us create the channel. self._channel = channel - self._stub = stubs.TaskHubSidecarServiceStub(channel) + self._stub = cast(_AsyncTaskHubSidecarServiceStub, stubs.TaskHubSidecarServiceStub(channel)) self._logger = shared.get_logger("async_client", log_handler, log_formatter) self.default_version = default_version self._payload_store = payload_store @@ -839,7 +963,7 @@ async def _maybe_recreate_channel(self) -> None: interceptors=self._interceptors, channel_options=self._channel_options, ) - self._stub = stubs.TaskHubSidecarServiceStub(self._channel) + self._stub = cast(_AsyncTaskHubSidecarServiceStub, stubs.TaskHubSidecarServiceStub(self._channel)) self._last_recreate_time = now self._client_failure_tracker.record_success() self._retired_channels.append(old_channel) @@ -940,11 +1064,11 @@ async def get_all_orchestration_states(self, ) -> list[OrchestrationState]: if orchestration_query is None: orchestration_query = OrchestrationQuery() - _continuation_token = None + _continuation_token: wrappers_pb2.StringValue | None = None self._logger.info(f"Querying orchestration instances with query: {orchestration_query}") - states = [] + states: list[OrchestrationState] = [] while True: req = build_query_instances_req(orchestration_query, _continuation_token) @@ -1101,11 +1225,11 @@ async def get_all_entities(self, entity_query: EntityQuery | None = None) -> list[EntityMetadata]: if entity_query is None: entity_query = EntityQuery() - _continuation_token = None + _continuation_token: wrappers_pb2.StringValue | None = None self._logger.info(f"Retrieving entities by filter: {entity_query}") - entities = [] + entities: list[EntityMetadata] = [] while True: query_request = build_query_entities_req(entity_query, _continuation_token) @@ -1127,7 +1251,7 @@ async def clean_entity_storage(self, empty_entities_removed = 0 orphaned_locks_released = 0 - _continuation_token = None + _continuation_token: wrappers_pb2.StringValue | None = None while True: req = pb.CleanEntityStorageRequest( diff --git a/durabletask/entities/entity_context.py b/durabletask/entities/entity_context.py index d7e57e4..ba2e18e 100644 --- a/durabletask/entities/entity_context.py +++ b/durabletask/entities/entity_context.py @@ -71,7 +71,7 @@ def get_state(self, intended_type: type[TState] | None = None, default: TState | """ return self._state.get_state(intended_type, default) - def set_state(self, new_state: Any): + def set_state(self, new_state: Any) -> None: """Set the state of the entity to a new value. Parameters @@ -93,7 +93,7 @@ def signal_entity(self, entity_instance_id: EntityInstanceId, operation: str, in input : Any, optional The input to provide to the entity for the operation. """ - encoded_input = shared.to_json(input) if input is not None else None + encoded_input: str | None = shared.to_json(input) if input is not None else None self._state.add_operation_action( pb.OperationAction( sendSignal=pb.SendSignalAction( @@ -124,7 +124,7 @@ def schedule_new_orchestration(self, orchestration_name: str, input: Any | None str The instance ID of the scheduled orchestration. """ - encoded_input = shared.to_json(input) if input is not None else None + encoded_input: str | None = shared.to_json(input) if input is not None else None if not instance_id: instance_id = uuid.uuid4().hex self._state.add_operation_action( diff --git a/durabletask/entities/entity_instance_id.py b/durabletask/entities/entity_instance_id.py index d8fc430..dc2340d 100644 --- a/durabletask/entities/entity_instance_id.py +++ b/durabletask/entities/entity_instance_id.py @@ -8,14 +8,14 @@ def __init__(self, entity: str, key: str): def __str__(self) -> str: return f"@{self.entity}@{self.key}" - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if not isinstance(other, EntityInstanceId): return False return self.entity == other.entity and self.key == other.key - def __lt__(self, other): + def __lt__(self, other: object) -> bool: if not isinstance(other, EntityInstanceId): - return self < other + return NotImplemented return str(self) < str(other) @staticmethod diff --git a/durabletask/entities/entity_lock.py b/durabletask/entities/entity_lock.py index 33c52d8..0b23701 100644 --- a/durabletask/entities/entity_lock.py +++ b/durabletask/entities/entity_lock.py @@ -16,4 +16,4 @@ def __enter__(self) -> EntityLock: return self def __exit__(self, *args: object) -> None: - self._context._exit_critical_section() + self._context._exit_critical_section() # pyright: ignore[reportPrivateUsage] diff --git a/durabletask/extensions/azure_blob_payloads/__init__.py b/durabletask/extensions/azure_blob_payloads/__init__.py index 83ff0dd..3dd38da 100644 --- a/durabletask/extensions/azure_blob_payloads/__init__.py +++ b/durabletask/extensions/azure_blob_payloads/__init__.py @@ -25,7 +25,7 @@ """ try: - from azure.storage.blob import BlobServiceClient # noqa: F401 + from azure.storage.blob import BlobServiceClient # noqa: F401 # pyright: ignore[reportUnusedImport] except ImportError as exc: raise ImportError( "The 'azure-storage-blob' package is required for blob payload " diff --git a/durabletask/extensions/azure_blob_payloads/blob_payload_store.py b/durabletask/extensions/azure_blob_payloads/blob_payload_store.py index 8a1b1c7..be3839c 100644 --- a/durabletask/extensions/azure_blob_payloads/blob_payload_store.py +++ b/durabletask/extensions/azure_blob_payloads/blob_payload_store.py @@ -10,8 +10,11 @@ import uuid from azure.core.exceptions import ResourceExistsError -from azure.storage.blob import BlobServiceClient -from azure.storage.blob.aio import BlobServiceClient as AsyncBlobServiceClient +from azure.storage.blob import BlobServiceClient, ContainerClient +from azure.storage.blob.aio import ( + BlobServiceClient as AsyncBlobServiceClient, + ContainerClient as AsyncContainerClient, +) from durabletask.extensions.azure_blob_payloads.options import BlobPayloadStoreOptions from durabletask.payload.store import PayloadStore @@ -40,47 +43,68 @@ class BlobPayloadStore(PayloadStore): options: A :class:`BlobPayloadStoreOptions` with all settings. """ - def __init__(self, options: BlobPayloadStoreOptions): + def __init__(self, options: BlobPayloadStoreOptions) -> None: if not options.connection_string and not options.account_url: raise ValueError( "Either 'connection_string' or 'account_url' (with 'credential') must be provided." ) - self._options = options - self._container_name = options.container_name - - # Optional kwargs shared by both sync and async clients. - extra_kwargs: dict = {} - if options.api_version: - extra_kwargs["api_version"] = options.api_version + self._options: BlobPayloadStoreOptions = options + self._container_name: str = options.container_name + self._blob_service_client: BlobServiceClient + self._async_blob_service_client: AsyncBlobServiceClient # Build sync client if options.connection_string: - self._blob_service_client = BlobServiceClient.from_connection_string( - options.connection_string, **extra_kwargs, - ) + if options.api_version: + self._blob_service_client = BlobServiceClient.from_connection_string( + options.connection_string, + api_version=options.api_version, + ) + else: + self._blob_service_client = BlobServiceClient.from_connection_string( + options.connection_string, + ) else: assert options.account_url is not None # guaranteed by validation above - self._blob_service_client = BlobServiceClient( - account_url=options.account_url, - credential=options.credential, - **extra_kwargs, - ) + if options.api_version: + self._blob_service_client = BlobServiceClient( + account_url=options.account_url, + credential=options.credential, + api_version=options.api_version, + ) + else: + self._blob_service_client = BlobServiceClient( + account_url=options.account_url, + credential=options.credential, + ) # Build async client if options.connection_string: - self._async_blob_service_client = AsyncBlobServiceClient.from_connection_string( - options.connection_string, **extra_kwargs, - ) + if options.api_version: + self._async_blob_service_client = AsyncBlobServiceClient.from_connection_string( + options.connection_string, + api_version=options.api_version, + ) + else: + self._async_blob_service_client = AsyncBlobServiceClient.from_connection_string( + options.connection_string, + ) else: assert options.account_url is not None # guaranteed by validation above - self._async_blob_service_client = AsyncBlobServiceClient( - account_url=options.account_url, - credential=options.credential, - **extra_kwargs, - ) - - self._ensure_container_created = False + if options.api_version: + self._async_blob_service_client = AsyncBlobServiceClient( + account_url=options.account_url, + credential=options.credential, + api_version=options.api_version, + ) + else: + self._async_blob_service_client = AsyncBlobServiceClient( + account_url=options.account_url, + credential=options.credential, + ) + + self._ensure_container_created: bool = False # ------------------------------------------------------------------ # Lifecycle / resource management @@ -121,7 +145,9 @@ def upload(self, data: bytes, *, instance_id: str | None = None) -> str: data = gzip.compress(data) blob_name = self._make_blob_name(instance_id) - container_client = self._blob_service_client.get_container_client(self._container_name) + container_client: ContainerClient = self._blob_service_client.get_container_client( + self._container_name + ) container_client.upload_blob(name=blob_name, data=data, overwrite=True) token = f"{_TOKEN_PREFIX}{self._container_name}:{blob_name}" @@ -130,7 +156,7 @@ def upload(self, data: bytes, *, instance_id: str | None = None) -> str: def download(self, token: str) -> bytes: container, blob_name = self._parse_token(token) - container_client = self._blob_service_client.get_container_client(container) + container_client: ContainerClient = self._blob_service_client.get_container_client(container) blob_data = container_client.download_blob(blob_name).readall() if self._options.enable_compression: @@ -150,7 +176,9 @@ async def upload_async(self, data: bytes, *, instance_id: str | None = None) -> data = gzip.compress(data) blob_name = self._make_blob_name(instance_id) - container_client = self._async_blob_service_client.get_container_client(self._container_name) + container_client: AsyncContainerClient = self._async_blob_service_client.get_container_client( + self._container_name + ) await container_client.upload_blob(name=blob_name, data=data, overwrite=True) token = f"{_TOKEN_PREFIX}{self._container_name}:{blob_name}" @@ -159,7 +187,9 @@ async def upload_async(self, data: bytes, *, instance_id: str | None = None) -> async def download_async(self, token: str) -> bytes: container, blob_name = self._parse_token(token) - container_client = self._async_blob_service_client.get_container_client(container) + container_client: AsyncContainerClient = self._async_blob_service_client.get_container_client( + container + ) stream = await container_client.download_blob(blob_name) blob_data = await stream.readall() @@ -206,7 +236,9 @@ def _make_blob_name(instance_id: str | None = None) -> str: def _ensure_container_sync(self) -> None: if self._ensure_container_created: return - container_client = self._blob_service_client.get_container_client(self._container_name) + container_client: ContainerClient = self._blob_service_client.get_container_client( + self._container_name + ) try: container_client.create_container() except ResourceExistsError: @@ -216,7 +248,9 @@ def _ensure_container_sync(self) -> None: async def _ensure_container_async(self) -> None: if self._ensure_container_created: return - container_client = self._async_blob_service_client.get_container_client(self._container_name) + container_client: AsyncContainerClient = self._async_blob_service_client.get_container_client( + self._container_name + ) try: await container_client.create_container() except ResourceExistsError: diff --git a/durabletask/grpc_options.py b/durabletask/grpc_options.py index c3efc8d..02cf766 100644 --- a/durabletask/grpc_options.py +++ b/durabletask/grpc_options.py @@ -74,7 +74,7 @@ class GrpcChannelOptions: keepalive_timeout_ms: int | None = None keepalive_permit_without_calls: bool | None = None retry_policy: GrpcRetryPolicyOptions | None = None - raw_options: list[tuple[str, Any]] = field(default_factory=list) + raw_options: list[tuple[str, Any]] = field(default_factory=lambda: []) def to_grpc_options(self) -> list[tuple[str, Any]]: options = list(self.raw_options) diff --git a/durabletask/history.py b/durabletask/history.py index 7e240f2..c72a376 100644 --- a/durabletask/history.py +++ b/durabletask/history.py @@ -3,9 +3,10 @@ from __future__ import annotations +from collections.abc import Callable from dataclasses import asdict, dataclass from datetime import datetime, timezone -from typing import Any +from typing import Any, cast from google.protobuf import json_format from google.protobuf.message import Message @@ -237,7 +238,7 @@ class ExecutionRewoundEvent(HistoryEvent): tags: dict[str, str] | None = None -def _from_protobuf(event: pb.HistoryEvent) -> HistoryEvent: +def _from_protobuf(event: pb.HistoryEvent) -> HistoryEvent: # pyright: ignore[reportUnusedFunction] event_type = event.WhichOneof('eventType') if event_type is None: raise ValueError('History event does not have an eventType set') @@ -329,13 +330,16 @@ def _to_serializable(value: Any) -> Any: if isinstance(value, datetime): return value.isoformat() if isinstance(value, list): - return [_to_serializable(item) for item in value] + return [_to_serializable(item) for item in cast(list[Any], value)] if isinstance(value, dict): - return {key: _to_serializable(item) for key, item in value.items()} + return { + key: _to_serializable(item) + for key, item in cast(dict[Any, Any], value).items() + } return value -_EVENT_CONVERTERS: dict[str, Any] = { +_EVENT_CONVERTERS: dict[str, Callable[[pb.HistoryEvent], HistoryEvent]] = { 'executionStarted': lambda event: ExecutionStartedEvent( **_base_kwargs(event), name=event.executionStarted.name, diff --git a/durabletask/internal/client_helpers.py b/durabletask/internal/client_helpers.py index 593ee9d..ef27c50 100644 --- a/durabletask/internal/client_helpers.py +++ b/durabletask/internal/client_helpers.py @@ -5,8 +5,11 @@ import logging import uuid +from collections.abc import Sequence from datetime import datetime, timezone -from typing import TYPE_CHECKING, Any, Sequence, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar + +from google.protobuf import wrappers_pb2 import durabletask.internal.helpers as helpers import durabletask.internal.orchestrator_service_pb2 as pb @@ -83,7 +86,7 @@ def build_schedule_new_orchestration_req( def build_query_instances_req( orchestration_query: OrchestrationQuery, - continuation_token) -> pb.QueryInstancesRequest: + continuation_token: wrappers_pb2.StringValue | None) -> pb.QueryInstancesRequest: """Build a QueryInstancesRequest from an OrchestrationQuery.""" return pb.QueryInstancesRequest( query=pb.InstanceQuery( @@ -115,7 +118,7 @@ def build_purge_by_filter_req( def build_query_entities_req( entity_query: EntityQuery, - continuation_token) -> pb.QueryEntitiesRequest: + continuation_token: wrappers_pb2.StringValue | None) -> pb.QueryEntitiesRequest: """Build a QueryEntitiesRequest from an EntityQuery.""" return pb.QueryEntitiesRequest( query=pb.EntityQuery( @@ -130,7 +133,10 @@ def build_query_entities_req( ) -def check_continuation_token(resp_token, prev_token, logger: logging.Logger) -> bool: +def check_continuation_token( + resp_token: wrappers_pb2.StringValue | None, + prev_token: wrappers_pb2.StringValue | None, + logger: logging.Logger) -> bool: """Check if a continuation token indicates more pages. Returns True to continue, False to stop.""" if resp_token and resp_token.value and resp_token.value != "0": logger.info(f"Received continuation token with value {resp_token.value}, fetching next page...") @@ -144,7 +150,7 @@ def check_continuation_token(resp_token, prev_token, logger: logging.Logger) -> def log_completion_state( logger: logging.Logger, instance_id: str, - state: OrchestrationState | None): + state: OrchestrationState | None) -> None: """Log the final state of a completed orchestration.""" if not state: return diff --git a/durabletask/internal/entity_state_shim.py b/durabletask/internal/entity_state_shim.py index 28ac15a..130de02 100644 --- a/durabletask/internal/entity_state_shim.py +++ b/durabletask/internal/entity_state_shim.py @@ -6,7 +6,7 @@ class StateShim: - def __init__(self, start_state): + def __init__(self, start_state: Any): self._current_state: Any = start_state self._checkpoint_state: Any = start_state self._operation_actions: list[pb.OperationAction] = [] @@ -41,24 +41,24 @@ def get_state(self, intended_type: type[TState] | None = None, default: TState | f"Could not convert state of type '{type(self._current_state).__name__}' to '{intended_type.__name__}'" ) from ex - def set_state(self, state): + def set_state(self, state: Any) -> None: self._current_state = state - def add_operation_action(self, action: pb.OperationAction): + def add_operation_action(self, action: pb.OperationAction) -> None: self._operation_actions.append(action) def get_operation_actions(self) -> list[pb.OperationAction]: return self._operation_actions[:self._actions_checkpoint_state] - def commit(self): + def commit(self) -> None: self._checkpoint_state = self._current_state self._actions_checkpoint_state = len(self._operation_actions) - def rollback(self): + def rollback(self) -> None: self._current_state = self._checkpoint_state self._operation_actions = self._operation_actions[:self._actions_checkpoint_state] - def reset(self): + def reset(self) -> None: self._current_state = None self._checkpoint_state = None self._operation_actions = [] diff --git a/durabletask/internal/grpc_interceptor.py b/durabletask/internal/grpc_interceptor.py index 882c88b..ede4a9c 100644 --- a/durabletask/internal/grpc_interceptor.py +++ b/durabletask/internal/grpc_interceptor.py @@ -1,17 +1,28 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from collections import namedtuple +from collections.abc import Callable, Iterable, Sequence +from typing import Any, NamedTuple, cast import grpc import grpc.aio +_MetadataValue = str | bytes +_MetadataEntry = tuple[str, _MetadataValue] +_Metadata = Sequence[_MetadataEntry] +_MetadataLike = _Metadata | grpc.aio.Metadata -class _ClientCallDetails( - namedtuple( - '_ClientCallDetails', - ['method', 'timeout', 'metadata', 'credentials', 'wait_for_ready', 'compression']), - grpc.ClientCallDetails): + +class _ClientCallDetailsBase(NamedTuple): + method: Any + timeout: Any + metadata: Any + credentials: Any + wait_for_ready: Any + compression: Any + + +class _ClientCallDetails(_ClientCallDetailsBase, grpc.ClientCallDetails): """This is an implementation of the ClientCallDetails interface needed for interceptors. This class takes six named values and inherits the ClientCallDetails from grpc package. This class encloses the values that describe a RPC to be invoked. @@ -19,11 +30,15 @@ class _ClientCallDetails( pass -class _AsyncClientCallDetails( - namedtuple( - '_AsyncClientCallDetails', - ['method', 'timeout', 'metadata', 'credentials', 'wait_for_ready']), - grpc.aio.ClientCallDetails): +class _AsyncClientCallDetailsBase(NamedTuple): + method: Any + timeout: Any + metadata: Any + credentials: Any + wait_for_ready: Any + + +class _AsyncClientCallDetails(_AsyncClientCallDetailsBase, grpc.aio.ClientCallDetails): """This is an implementation of the aio ClientCallDetails interface needed for async interceptors. This class takes five named values and inherits the ClientCallDetails from grpc.aio package. This class encloses the values that describe a RPC to be invoked. @@ -31,15 +46,18 @@ class _AsyncClientCallDetails( pass -def _apply_metadata(client_call_details, metadata): +def _apply_metadata( + client_call_details: grpc.ClientCallDetails | grpc.aio.ClientCallDetails, + metadata: _Metadata | None) -> _MetadataLike | None: """Shared logic for applying metadata to call details. Returns the updated metadata list.""" + existing_metadata = cast(_MetadataLike | None, client_call_details.metadata) if metadata is None: - return client_call_details.metadata + return existing_metadata - if client_call_details.metadata is not None: - new_metadata = list(client_call_details.metadata) + if existing_metadata is not None: + new_metadata = list(cast(Iterable[_MetadataEntry], existing_metadata)) else: - new_metadata = [] + new_metadata: list[_MetadataEntry] = [] new_metadata.extend(metadata) return new_metadata @@ -68,21 +86,37 @@ def _intercept_call( client_call_details.method, client_call_details.timeout, new_metadata, client_call_details.credentials, client_call_details.wait_for_ready, client_call_details.compression) - def intercept_unary_unary(self, continuation, client_call_details, request): + def intercept_unary_unary( + self, + continuation: Callable[[grpc.ClientCallDetails, Any], Any], + client_call_details: grpc.ClientCallDetails, + request: Any) -> Any: new_client_call_details = self._intercept_call(client_call_details) return continuation(new_client_call_details, request) - def intercept_unary_stream(self, continuation, client_call_details, request): + def intercept_unary_stream( + self, + continuation: Callable[[grpc.ClientCallDetails, Any], Any], + client_call_details: grpc.ClientCallDetails, + request: Any) -> Any: new_client_call_details = self._intercept_call(client_call_details) return continuation(new_client_call_details, request) - def intercept_stream_unary(self, continuation, client_call_details, request): + def intercept_stream_unary( + self, + continuation: Callable[[grpc.ClientCallDetails, Any], Any], + client_call_details: grpc.ClientCallDetails, + request_iterator: Any) -> Any: new_client_call_details = self._intercept_call(client_call_details) - return continuation(new_client_call_details, request) + return continuation(new_client_call_details, request_iterator) - def intercept_stream_stream(self, continuation, client_call_details, request): + def intercept_stream_stream( + self, + continuation: Callable[[grpc.ClientCallDetails, Any], Any], + client_call_details: grpc.ClientCallDetails, + request_iterator: Any) -> Any: new_client_call_details = self._intercept_call(client_call_details) - return continuation(new_client_call_details, request) + return continuation(new_client_call_details, request_iterator) class DefaultAsyncClientInterceptorImpl( @@ -110,18 +144,34 @@ async def _intercept_call( client_call_details.wait_for_ready, ) - async def intercept_unary_unary(self, continuation, client_call_details, request): + async def intercept_unary_unary( + self, + continuation: Callable[[grpc.aio.ClientCallDetails, Any], Any], + client_call_details: grpc.aio.ClientCallDetails, + request: Any) -> Any: new_client_call_details = await self._intercept_call(client_call_details) return await continuation(new_client_call_details, request) - async def intercept_unary_stream(self, continuation, client_call_details, request): + async def intercept_unary_stream( + self, + continuation: Callable[[grpc.aio.ClientCallDetails, Any], Any], + client_call_details: grpc.aio.ClientCallDetails, + request: Any) -> Any: new_client_call_details = await self._intercept_call(client_call_details) return await continuation(new_client_call_details, request) - async def intercept_stream_unary(self, continuation, client_call_details, request_iterator): + async def intercept_stream_unary( + self, + continuation: Callable[[grpc.aio.ClientCallDetails, Any], Any], + client_call_details: grpc.aio.ClientCallDetails, + request_iterator: Any) -> Any: new_client_call_details = await self._intercept_call(client_call_details) return await continuation(new_client_call_details, request_iterator) - async def intercept_stream_stream(self, continuation, client_call_details, request_iterator): + async def intercept_stream_stream( + self, + continuation: Callable[[grpc.aio.ClientCallDetails, Any], Any], + client_call_details: grpc.aio.ClientCallDetails, + request_iterator: Any) -> Any: new_client_call_details = await self._intercept_call(client_call_details) return await continuation(new_client_call_details, request_iterator) diff --git a/durabletask/internal/grpc_resiliency.py b/durabletask/internal/grpc_resiliency.py index f363872..04dda4f 100644 --- a/durabletask/internal/grpc_resiliency.py +++ b/durabletask/internal/grpc_resiliency.py @@ -3,8 +3,9 @@ import random import threading +from collections.abc import Callable from dataclasses import dataclass, field -from typing import Callable +from typing import Any import grpc import grpc.aio @@ -108,7 +109,11 @@ def __init__( self._failure_tracker = failure_tracker self._on_recreate = on_recreate - def intercept_unary_unary(self, continuation, client_call_details, request): + def intercept_unary_unary( + self, + continuation: Callable[[grpc.ClientCallDetails, Any], Any], + client_call_details: grpc.ClientCallDetails, + request: Any) -> Any: response = continuation(client_call_details, request) error = response.exception() self._record_outcome(client_call_details.method, error) @@ -150,7 +155,11 @@ def __init__( self._failure_tracker = failure_tracker self._on_recreate = on_recreate - async def intercept_unary_unary(self, continuation, client_call_details, request): + async def intercept_unary_unary( + self, + continuation: Callable[[grpc.aio.ClientCallDetails, Any], Any], + client_call_details: grpc.aio.ClientCallDetails, + request: Any) -> Any: try: response = await continuation(client_call_details, request) except Exception as exc: diff --git a/durabletask/internal/helpers.py b/durabletask/internal/helpers.py index 308dc66..2342afd 100644 --- a/durabletask/internal/helpers.py +++ b/durabletask/internal/helpers.py @@ -316,7 +316,7 @@ def new_create_sub_orchestration_action( )) -def is_empty(v: wrappers_pb2.StringValue): +def is_empty(v: wrappers_pb2.StringValue | None) -> bool: return v is None or v.value == '' diff --git a/durabletask/internal/history_helpers.py b/durabletask/internal/history_helpers.py index defd31b..53dd8b8 100644 --- a/durabletask/internal/history_helpers.py +++ b/durabletask/internal/history_helpers.py @@ -3,7 +3,8 @@ from __future__ import annotations -from typing import AsyncIterable, Iterable +from collections.abc import AsyncIterable, Iterable +from typing import Any import durabletask.history as history import durabletask.internal.orchestrator_service_pb2 as pb @@ -31,7 +32,7 @@ async def collect_history_events_async( return events -def history_event_to_dict(event: history.HistoryEvent) -> dict: +def history_event_to_dict(event: history.HistoryEvent) -> dict[str, Any]: return history.to_dict(event) @@ -48,7 +49,7 @@ def _clone_and_convert_events( event = pb.HistoryEvent() event.CopyFrom(source_event) payload_helpers.deexternalize_payloads(event, payload_store) - events.append(history._from_protobuf(event)) + events.append(history._from_protobuf(event)) # pyright: ignore[reportPrivateUsage] return events @@ -64,5 +65,5 @@ async def _clone_and_convert_events_async( event = pb.HistoryEvent() event.CopyFrom(source_event) await payload_helpers.deexternalize_payloads_async(event, payload_store) - events.append(history._from_protobuf(event)) + events.append(history._from_protobuf(event)) # pyright: ignore[reportPrivateUsage] return events diff --git a/durabletask/internal/orchestration_entity_context.py b/durabletask/internal/orchestration_entity_context.py index e1cb178..f4d5d1d 100644 --- a/durabletask/internal/orchestration_entity_context.py +++ b/durabletask/internal/orchestration_entity_context.py @@ -1,5 +1,6 @@ +from collections.abc import Generator from datetime import datetime -from typing import Generator +from typing import Any from durabletask.internal.helpers import get_string_value import durabletask.internal.orchestrator_service_pb2 as pb @@ -12,7 +13,7 @@ def __init__(self, instance_id: str): self.lock_acquisition_pending = False - self.critical_section_id = None + self.critical_section_id: str | None = None self.critical_section_locks: list[EntityInstanceId] = [] self.available_locks: list[EntityInstanceId] = [] @@ -57,7 +58,7 @@ def recover_lock_after_call(self, target_instance_id: EntityInstanceId): if self.is_inside_critical_section: self.available_locks.append(target_instance_id) - def emit_lock_release_messages(self): + def emit_lock_release_messages(self) -> Generator[pb.SendEntityMessageAction, None, None]: if self.is_inside_critical_section: for entity_id in self.critical_section_locks: unlock_event = pb.SendEntityMessageAction(entityUnlockSent=pb.EntityUnlockSentEvent( @@ -71,9 +72,9 @@ def emit_lock_release_messages(self): self.available_locks = [] self.critical_section_id = None - def emit_request_message(self, target, operation_name: str, one_way: bool, operation_id: str, + def emit_request_message(self, target: Any, operation_name: str, one_way: bool, operation_id: str, scheduled_time_utc: datetime, input: str | None, - request_time: datetime | None = None, create_trace: bool = False): + request_time: datetime | None = None, create_trace: bool = False) -> Any: raise NotImplementedError() def emit_acquire_message( @@ -87,7 +88,7 @@ def emit_acquire_message( # Acquire the locks in a globally fixed order to avoid deadlocks # Also remove duplicates - this can be optimized for perf if necessary entity_ids = sorted(entities) - entity_ids_dedup = [] + entity_ids_dedup: list[EntityInstanceId] = [] for i, entity_id in enumerate(entity_ids): if entity_id != entity_ids[i - 1] if i > 0 else True: entity_ids_dedup.append(entity_id) @@ -106,14 +107,14 @@ def emit_acquire_message( return request, target - def complete_acquire(self, critical_section_id): + def complete_acquire(self, critical_section_id: str) -> None: if self.critical_section_id != critical_section_id: raise RuntimeError(f"Unexpected lock acquire for critical section ID '{critical_section_id}' (expected '{self.critical_section_id}')") self.available_locks = self.critical_section_locks self.lock_acquisition_pending = False - def adjust_outgoing_message(self, instance_id: str, request_message, capped_time: datetime) -> str: + def adjust_outgoing_message(self, instance_id: str, request_message: Any, capped_time: datetime) -> str: raise NotImplementedError() - def deserialize_entity_response_event(self, event_content: str): + def deserialize_entity_response_event(self, event_content: str) -> Any: raise NotImplementedError() diff --git a/durabletask/internal/proto_task_hub_sidecar_service_stub.py b/durabletask/internal/proto_task_hub_sidecar_service_stub.py index 0f7880f..de5a4d5 100644 --- a/durabletask/internal/proto_task_hub_sidecar_service_stub.py +++ b/durabletask/internal/proto_task_hub_sidecar_service_stub.py @@ -1,4 +1,5 @@ -from typing import Any, Callable, Protocol +from collections.abc import Callable +from typing import Any, Protocol class ProtoTaskHubSidecarServiceStub(Protocol): diff --git a/durabletask/internal/shared.py b/durabletask/internal/shared.py index fa3d3a6..f8afc7c 100644 --- a/durabletask/internal/shared.py +++ b/durabletask/internal/shared.py @@ -158,46 +158,48 @@ def get_logger( return logger -def to_json(obj): +def to_json(obj: Any) -> str: return json.dumps(obj, cls=InternalJSONEncoder) -def from_json(json_str): +def from_json(json_str: str | bytes | bytearray) -> Any: return json.loads(json_str, cls=InternalJSONDecoder) class InternalJSONEncoder(json.JSONEncoder): """JSON encoder that supports serializing specific Python types.""" - def encode(self, obj: Any) -> str: + def encode(self, o: Any) -> str: # pyright: ignore[reportIncompatibleMethodOverride] # if the object is a namedtuple, convert it to a dict with the AUTO_SERIALIZED key added - if isinstance(obj, tuple) and hasattr(obj, "_fields") and hasattr(obj, "_asdict"): - d = obj._asdict() # type: ignore - d[AUTO_SERIALIZED] = True - obj = d - return super().encode(obj) - - def default(self, obj): - if dataclasses.is_dataclass(obj): + if isinstance(o, tuple): + namedtuple_obj: Any = o # pyright: ignore[reportUnknownVariableType] + if hasattr(namedtuple_obj, "_fields") and hasattr(namedtuple_obj, "_asdict"): + d: dict[str, Any] = namedtuple_obj._asdict() + d[AUTO_SERIALIZED] = True + o = d + return super().encode(o) + + def default(self, o: Any) -> Any: # pyright: ignore[reportIncompatibleMethodOverride] + if dataclasses.is_dataclass(o) and not isinstance(o, type): # Dataclasses are not serializable by default, so we convert them to a dict and mark them for # automatic deserialization by the receiver - d = dataclasses.asdict(obj) # type: ignore + d: dict[str, Any] = dataclasses.asdict(o) d[AUTO_SERIALIZED] = True return d - elif isinstance(obj, SimpleNamespace): + elif isinstance(o, SimpleNamespace): # Most commonly used for serializing custom objects that were previously serialized using our encoder - d = vars(obj) + d = vars(o) d[AUTO_SERIALIZED] = True return d # This will typically raise a TypeError - return json.JSONEncoder.default(self, obj) + return json.JSONEncoder.default(self, o) class InternalJSONDecoder(json.JSONDecoder): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): super().__init__(object_hook=self.dict_to_object, *args, **kwargs) - def dict_to_object(self, d: dict[str, Any]): + def dict_to_object(self, d: dict[str, Any]) -> Any: # If the object was serialized by the InternalJSONEncoder, deserialize it as a SimpleNamespace if d.pop(AUTO_SERIALIZED, False): return SimpleNamespace(**d) diff --git a/durabletask/internal/tracing.py b/durabletask/internal/tracing.py index 190d848..98ae1e9 100644 --- a/durabletask/internal/tracing.py +++ b/durabletask/internal/tracing.py @@ -19,7 +19,7 @@ import time from contextlib import contextmanager from datetime import datetime -from typing import Any +from typing import TYPE_CHECKING, Any from google.protobuf import timestamp_pb2, wrappers_pb2 @@ -30,35 +30,56 @@ # --------------------------------------------------------------------------- # Lazy / optional OpenTelemetry imports # --------------------------------------------------------------------------- -try: +# Declare the optional symbols for the type checker so the rest of the +# module can be statically type-checked even when opentelemetry is absent. +# At runtime, ``_OTEL_AVAILABLE`` gates every real use. +if TYPE_CHECKING: from opentelemetry import context as otel_context from opentelemetry import trace - from opentelemetry.trace import ( - SpanKind, # type: ignore[no-redef] - StatusCode, # type: ignore[no-redef] - ) + from opentelemetry.trace import SpanKind, StatusCode from opentelemetry.trace.propagation.tracecontext import ( TraceContextTextMapPropagator, ) - _OTEL_AVAILABLE = True +try: + from opentelemetry import context as _otel_context # noqa: F401 + from opentelemetry import trace as _trace # noqa: F401 + from opentelemetry.trace import SpanKind as _SpanKind # noqa: F401 + from opentelemetry.trace import StatusCode as _StatusCode # noqa: F401 + from opentelemetry.trace.propagation.tracecontext import ( # noqa: F401 + TraceContextTextMapPropagator as _TraceContextTextMapPropagator, + ) + + _OTEL_AVAILABLE: bool = True + # Re-bind for runtime usage so the module exposes the real classes. + otel_context = _otel_context # noqa: F811 # type: ignore[assignment] # pyright: ignore[reportConstantRedefinition] + trace = _trace # noqa: F811 # type: ignore[assignment] # pyright: ignore[reportConstantRedefinition] + SpanKind = _SpanKind # noqa: F811 # type: ignore[assignment] # pyright: ignore[reportConstantRedefinition] + StatusCode = _StatusCode # noqa: F811 # type: ignore[assignment] # pyright: ignore[reportConstantRedefinition] + TraceContextTextMapPropagator = _TraceContextTextMapPropagator # noqa: F811 # type: ignore[assignment] # pyright: ignore[reportConstantRedefinition] except ImportError: # pragma: no cover - _OTEL_AVAILABLE = False - # Provide stub for SpanKind so callers can reference tracing.SpanKind - # without guarding every reference with OTEL_AVAILABLE checks. + _OTEL_AVAILABLE = False # pyright: ignore[reportConstantRedefinition] - class SpanKind: # type: ignore[no-redef] + # Inject runtime stubs via globals so pyright (which only sees the + # TYPE_CHECKING block) doesn't flag these as incompatible reassignments. + class _SpanKindStub: INTERNAL: Any = None CLIENT: Any = None SERVER: Any = None PRODUCER: Any = None CONSUMER: Any = None - class StatusCode: # type: ignore[no-redef] + class _StatusCodeStub: OK: Any = None ERROR: Any = None UNSET: Any = None + globals()["SpanKind"] = _SpanKindStub + globals()["StatusCode"] = _StatusCodeStub + globals()["otel_context"] = None + globals()["trace"] = None + globals()["TraceContextTextMapPropagator"] = None + # Re-export so callers can check without importing opentelemetry themselves. OTEL_AVAILABLE = _OTEL_AVAILABLE @@ -479,7 +500,7 @@ def _is_deferred_span_capable() -> bool: for downstream SERVER spans instead. """ try: - from opentelemetry.sdk.trace import ReadableSpan # noqa: F401 + from opentelemetry.sdk.trace import ReadableSpan # noqa: F401 # pyright: ignore[reportUnusedImport] from opentelemetry.sdk.trace import TracerProvider as SdkTracerProvider except (ImportError, AttributeError): return False diff --git a/durabletask/payload/helpers.py b/durabletask/payload/helpers.py index fd94a2e..b42ed11 100644 --- a/durabletask/payload/helpers.py +++ b/durabletask/payload/helpers.py @@ -12,7 +12,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from google.protobuf import message as proto_message from google.protobuf import wrappers_pb2 @@ -84,7 +84,7 @@ async def deexternalize_payloads_async( # Internal recursive walkers – sync # ------------------------------------------------------------------ -def _is_map_field(fd) -> bool: +def _is_map_field(fd: Any) -> bool: """Return True if the field descriptor represents a protobuf map field.""" mt = fd.message_type return mt is not None and fd.is_repeated and mt.GetOptions().map_entry diff --git a/durabletask/task.py b/durabletask/task.py index 9b797d8..b1ae27c 100644 --- a/durabletask/task.py +++ b/durabletask/task.py @@ -7,8 +7,9 @@ import logging import math from abc import ABC, abstractmethod +from collections.abc import Callable, Generator, Sequence from datetime import datetime, timedelta, timezone -from typing import Any, Callable, Generator, Generic, TypeAlias, TypeVar +from typing import TYPE_CHECKING, Any, Generic, TypeAlias, TypeVar, cast from durabletask.entities import DurableEntity, EntityInstanceId, EntityLock, EntityContext import durabletask.internal.helpers as pbh @@ -99,7 +100,7 @@ def set_custom_status(self, custom_status: Any) -> None: pass @abstractmethod - def create_timer(self, fire_at: datetime | timedelta) -> CancellableTask: + def create_timer(self, fire_at: datetime | timedelta) -> TimerTask: """Create a Timer Task to fire after at the specified deadline. Parameters @@ -109,7 +110,7 @@ def create_timer(self, fire_at: datetime | timedelta) -> CancellableTask: Returns ------- - Task + TimerTask A Durable Timer Task that schedules the timer to wake up the orchestrator """ pass @@ -143,7 +144,7 @@ def call_activity(self, activity: Activity[TInput, TOutput] | str, *, def call_entity(self, entity: EntityInstanceId, operation: str, - input: TInput | None = None) -> CompletableTask[Any]: + input: Any = None) -> CompletableTask[Any]: """Schedule entity function for execution. Parameters @@ -167,7 +168,7 @@ def signal_entity( self, entity_id: EntityInstanceId, operation_name: str, - input: TInput | None = None + input: Any = None ) -> None: """Signal an entity function for execution. @@ -232,7 +233,7 @@ def call_sub_orchestrator(self, orchestrator: Orchestrator[TInput, TOutput] | st # TOOD: Add a timeout parameter, which allows the task to be cancelled if the event is # not received within the specified timeout. This requires support for task cancellation. @abstractmethod - def wait_for_external_event(self, name: str) -> CancellableTask: + def wait_for_external_event(self, name: str) -> CancellableTask[Any]: """Wait asynchronously for an event to be raised with the name `name`. Parameters @@ -242,7 +243,7 @@ def wait_for_external_event(self, name: str) -> CancellableTask: Returns ------- - Task[TOutput] + CancellableTask[Any] A Durable Task that completes when the event is received. """ pass @@ -300,7 +301,16 @@ def create_replay_safe_logger(self, logger: logging.Logger) -> ReplaySafeLogger: return ReplaySafeLogger(logger, lambda: self.is_replaying) -class ReplaySafeLogger(logging.LoggerAdapter): +if TYPE_CHECKING: + # logging.LoggerAdapter is generic in stubs but is not subscriptable + # at runtime before Python 3.11. Use a TYPE_CHECKING alias so the + # base class evaluates correctly at runtime. + _LoggerAdapterBase = logging.LoggerAdapter[logging.Logger] +else: + _LoggerAdapterBase = logging.LoggerAdapter + + +class ReplaySafeLogger(_LoggerAdapterBase): """A logger adapter that suppresses log messages during orchestration replay. This class extends :class:`logging.LoggerAdapter` and only emits log @@ -378,7 +388,7 @@ class Task(ABC, Generic[T]): """Abstract base class for asynchronous tasks in a durable orchestration.""" _result: T _exception: TaskFailedError | None - _parent: CompositeTask[T] | None + _parent: CompositeTask[Any] | None def __init__(self) -> None: super().__init__() @@ -413,9 +423,9 @@ def get_exception(self) -> TaskFailedError: class CompositeTask(Task[T]): """A task that is composed of other tasks.""" - _tasks: list[Task] + _tasks: list[Task[Any]] - def __init__(self, tasks: list[Task]): + def __init__(self, tasks: list[Task[Any]]): super().__init__() self._tasks = tasks self._completed_tasks = 0 @@ -425,11 +435,11 @@ def __init__(self, tasks: list[Task]): if task.is_complete: self.on_child_completed(task) - def get_tasks(self) -> list[Task]: + def get_tasks(self) -> list[Task[Any]]: return self._tasks @abstractmethod - def on_child_completed(self, task: Task[T]): + def on_child_completed(self, task: Task[Any]) -> None: pass @@ -437,7 +447,7 @@ class WhenAllTask(CompositeTask[list[T]]): """A task that completes when all of its child tasks complete.""" def __init__(self, tasks: list[Task[T]]): - super().__init__(tasks) + super().__init__(cast(list[Task[Any]], tasks)) self._completed_tasks = 0 self._failed_tasks = 0 @@ -446,7 +456,7 @@ def pending_tasks(self) -> int: """Returns the number of tasks that have not yet completed.""" return len(self._tasks) - self._completed_tasks - def on_child_completed(self, task: Task[T]): + def on_child_completed(self, task: Task[Any]) -> None: if self.is_complete: raise ValueError('The task has already completed.') self._completed_tasks += 1 @@ -455,7 +465,7 @@ def on_child_completed(self, task: Task[T]): self._is_complete = True if self._completed_tasks == len(self._tasks): # The order of the result MUST match the order of the tasks provided to the constructor. - self._result = [task.get_result() for task in self._tasks] + self._result = [child.get_result() for child in self._tasks] self._is_complete = True def get_completed_tasks(self) -> int: @@ -464,9 +474,9 @@ def get_completed_tasks(self) -> int: class CompletableTask(Task[T]): - def __init__(self): + def __init__(self) -> None: super().__init__() - self._retryable_parent = None + self._retryable_parent: RetryableTask[Any] | None = None def complete(self, result: T): if self._is_complete: @@ -573,7 +583,7 @@ def __init__(self, final_fire_at: datetime | None = None, self._final_fire_at = final_fire_at self._maximum_timer_interval = maximum_timer_interval - def set_retryable_parent(self, retryable_task: RetryableTask): + def set_retryable_parent(self, retryable_task: RetryableTask[Any]) -> None: self._retryable_parent = retryable_task def _handle_timer_fired(self, current_utc_datetime: datetime) -> datetime | None: @@ -585,22 +595,25 @@ def _handle_timer_fired(self, current_utc_datetime: datetime) -> datetime | None return None def _get_next_fire_at(self, current_utc_datetime: datetime) -> datetime: + # _handle_timer_fired guards both attributes before calling this method. + assert self._final_fire_at is not None + assert self._maximum_timer_interval is not None if current_utc_datetime + self._maximum_timer_interval < self._final_fire_at: return current_utc_datetime + self._maximum_timer_interval return self._final_fire_at -class WhenAnyTask(CompositeTask[Task]): +class WhenAnyTask(CompositeTask[Task[T]], Generic[T]): """A task that completes when any of its child tasks complete.""" - def __init__(self, tasks: list[Task]): - super().__init__(tasks) + def __init__(self, tasks: list[Task[T]]): + super().__init__(cast(list[Task[Any]], tasks)) - def on_child_completed(self, task: Task): + def on_child_completed(self, task: Task[Any]) -> None: # The first task to complete is the result of the WhenAnyTask. if not self.is_complete: self._is_complete = True - self._result = task + self._result = cast(Task[T], task) def when_all(tasks: list[Task[T]]) -> WhenAllTask[T]: @@ -608,9 +621,9 @@ def when_all(tasks: list[Task[T]]) -> WhenAllTask[T]: return WhenAllTask(tasks) -def when_any(tasks: list[Task]) -> WhenAnyTask: +def when_any(tasks: Sequence[Task[T]]) -> WhenAnyTask[T]: """Returns a task that completes when any of the provided tasks complete or fail.""" - return WhenAnyTask(tasks) + return WhenAnyTask(list(tasks)) class ActivityContext: @@ -723,15 +736,15 @@ def retry_timeout(self) -> timedelta | None: return self._retry_timeout -def get_entity_name(fn: Entity) -> str: +def get_entity_name(fn: Entity[Any, Any]) -> str: if hasattr(fn, "__durable_entity_name__"): return getattr(fn, "__durable_entity_name__") if isinstance(fn, type) and issubclass(fn, DurableEntity): return fn.__name__ - return get_name(fn) + return get_name(cast(Callable[..., Any], fn)) -def get_name(fn: Callable) -> str: +def get_name(fn: Callable[..., Any]) -> str: """Returns the name of the provided function""" name = fn.__name__ if name == '': diff --git a/durabletask/testing/in_memory_backend.py b/durabletask/testing/in_memory_backend.py index dd99b8d..0eb761c 100644 --- a/durabletask/testing/in_memory_backend.py +++ b/durabletask/testing/in_memory_backend.py @@ -16,9 +16,10 @@ import time import uuid from collections import deque +from collections.abc import Callable, Iterable, Iterator from dataclasses import dataclass, field from datetime import datetime, timezone -from typing import Callable +from typing import TypeAlias, cast import grpc from concurrent import futures @@ -30,6 +31,14 @@ from durabletask.entities.entity_instance_id import EntityInstanceId +_FilterMap: TypeAlias = dict[str, frozenset[str]] +_WorkItemFilter: TypeAlias = _FilterMap | None + + +def _new_history_event_list() -> list[pb.HistoryEvent]: + return [] + + @dataclass class OrchestrationInstance: """Internal orchestration instance state stored by the in-memory backend.""" @@ -44,9 +53,9 @@ class OrchestrationInstance: last_updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) completed_at: datetime | None = None failure_details: pb.TaskFailureDetails | None = None - history: list[pb.HistoryEvent] = field(default_factory=list) - pending_events: list[pb.HistoryEvent] = field(default_factory=list) - dispatched_events: list[pb.HistoryEvent] = field(default_factory=list) + history: list[pb.HistoryEvent] = field(default_factory=_new_history_event_list) + pending_events: list[pb.HistoryEvent] = field(default_factory=_new_history_event_list) + dispatched_events: list[pb.HistoryEvent] = field(default_factory=_new_history_event_list) completion_token: int = 0 tags: dict[str, str] | None = None @@ -69,8 +78,8 @@ class EntityState: serialized_state: str | None = None last_modified_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) locked_by: str | None = None - pending_operations: list[pb.HistoryEvent] = field(default_factory=list) - dispatched_operations: list[pb.HistoryEvent] = field(default_factory=list) + pending_operations: list[pb.HistoryEvent] = field(default_factory=_new_history_event_list) + dispatched_operations: list[pb.HistoryEvent] = field(default_factory=_new_history_event_list) completion_token: int = 0 @@ -153,10 +162,15 @@ def start(self) -> str: The address the server is listening on (e.g., "localhost:50051") """ self._shutdown_event.clear() - self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) - stubs.add_TaskHubSidecarServiceServicer_to_server(self, self._server) - self._server.add_insecure_port(f'[::]:{self._port}') - self._server.start() + server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + self._server = server + add_servicer = cast( + Callable[[InMemoryOrchestrationBackend, grpc.Server], None], + stubs.add_TaskHubSidecarServiceServicer_to_server, # pyright: ignore[reportUnknownMemberType] + ) + add_servicer(self, server) + server.add_insecure_port(f'[::]:{self._port}') + server.start() self._logger.info(f"In-memory backend started on port {self._port}") return f"localhost:{self._port}" @@ -197,11 +211,11 @@ def reset(self): # gRPC Service Methods - def Hello(self, request, context): + def Hello(self, request: empty_pb2.Empty, context: grpc.ServicerContext) -> empty_pb2.Empty: """Sends a hello request to the sidecar service.""" return empty_pb2.Empty() - def StartInstance(self, request: pb.CreateInstanceRequest, context): + def StartInstance(self, request: pb.CreateInstanceRequest, context: grpc.ServicerContext) -> pb.CreateInstanceResponse: """Starts a new orchestration instance.""" instance_id = request.instanceId if request.instanceId else uuid.uuid4().hex @@ -269,7 +283,7 @@ def StartInstance(self, request: pb.CreateInstanceRequest, context): return pb.CreateInstanceResponse(instanceId=instance_id) - def GetInstance(self, request: pb.GetInstanceRequest, context): + def GetInstance(self, request: pb.GetInstanceRequest, context: grpc.ServicerContext) -> pb.GetInstanceResponse: """Gets the status of an existing orchestration instance.""" with self._lock: instance = self._instances.get(request.instanceId) @@ -278,7 +292,7 @@ def GetInstance(self, request: pb.GetInstanceRequest, context): return self._build_instance_response(instance, request.getInputsAndOutputs) - def WaitForInstanceStart(self, request: pb.GetInstanceRequest, context): + def WaitForInstanceStart(self, request: pb.GetInstanceRequest, context: grpc.ServicerContext) -> pb.GetInstanceResponse: """Waits for an orchestration instance to reach a running or completion state.""" def predicate(inst: OrchestrationInstance) -> bool: return inst.status != pb.ORCHESTRATION_STATUS_PENDING @@ -294,7 +308,7 @@ def predicate(inst: OrchestrationInstance) -> bool: return self._build_instance_response(instance, request.getInputsAndOutputs) - def WaitForInstanceCompletion(self, request: pb.GetInstanceRequest, context): + def WaitForInstanceCompletion(self, request: pb.GetInstanceRequest, context: grpc.ServicerContext) -> pb.GetInstanceResponse: """Waits for an orchestration instance to reach a completion state.""" instance = self._wait_for_state( request.instanceId, @@ -311,7 +325,7 @@ def WaitForInstanceCompletion(self, request: pb.GetInstanceRequest, context): return self._build_instance_response(instance, request.getInputsAndOutputs) - def RaiseEvent(self, request: pb.RaiseEventRequest, context): + def RaiseEvent(self, request: pb.RaiseEventRequest, context: grpc.ServicerContext) -> pb.RaiseEventResponse: """Raises an event to a running orchestration instance.""" with self._lock: instance = self._instances.get(request.instanceId) @@ -338,7 +352,7 @@ def RaiseEvent(self, request: pb.RaiseEventRequest, context): self._logger.info(f"Raised event '{request.name}' for instance '{request.instanceId}'") return pb.RaiseEventResponse() - def TerminateInstance(self, request: pb.TerminateRequest, context): + def TerminateInstance(self, request: pb.TerminateRequest, context: grpc.ServicerContext) -> pb.TerminateResponse: """Terminates a running orchestration instance.""" with self._lock: self._terminate_instance_internal( @@ -349,7 +363,7 @@ def TerminateInstance(self, request: pb.TerminateRequest, context): return pb.TerminateResponse() - def SuspendInstance(self, request: pb.SuspendRequest, context): + def SuspendInstance(self, request: pb.SuspendRequest, context: grpc.ServicerContext) -> pb.SuspendResponse: """Suspends a running orchestration instance.""" with self._lock: instance = self._instances.get(request.instanceId) @@ -369,7 +383,7 @@ def SuspendInstance(self, request: pb.SuspendRequest, context): self._logger.info(f"Suspended instance '{request.instanceId}'") return pb.SuspendResponse() - def ResumeInstance(self, request: pb.ResumeRequest, context): + def ResumeInstance(self, request: pb.ResumeRequest, context: grpc.ServicerContext) -> pb.ResumeResponse: """Resumes a suspended orchestration instance.""" with self._lock: instance = self._instances.get(request.instanceId) @@ -386,7 +400,7 @@ def ResumeInstance(self, request: pb.ResumeRequest, context): self._logger.info(f"Resumed instance '{request.instanceId}'") return pb.ResumeResponse() - def PurgeInstances(self, request: pb.PurgeInstancesRequest, context): + def PurgeInstances(self, request: pb.PurgeInstancesRequest, context: grpc.ServicerContext) -> pb.PurgeInstancesResponse: """Purges orchestration instances from the store.""" purged_count = 0 @@ -402,7 +416,7 @@ def PurgeInstances(self, request: pb.PurgeInstancesRequest, context): elif request.HasField("purgeInstanceFilter"): # Filter-based purge pf = request.purgeInstanceFilter - to_purge = [] + to_purge: list[str] = [] for iid, inst in self._instances.items(): if not self._is_terminal_status(inst.status): continue @@ -424,7 +438,7 @@ def PurgeInstances(self, request: pb.PurgeInstancesRequest, context): isComplete=wrappers_pb2.BoolValue(value=True), ) - def RestartInstance(self, request: pb.RestartInstanceRequest, context): + def RestartInstance(self, request: pb.RestartInstanceRequest, context: grpc.ServicerContext) -> pb.RestartInstanceResponse: """Restarts a completed orchestration instance.""" with self._lock: instance = self._instances.get(request.instanceId) @@ -460,10 +474,10 @@ def RestartInstance(self, request: pb.RestartInstanceRequest, context): f"Restarted instance '{request.instanceId}' as '{new_instance_id}'") return pb.RestartInstanceResponse(instanceId=new_instance_id) - def ListInstanceIds(self, request: pb.ListInstanceIdsRequest, context): + def ListInstanceIds(self, request: pb.ListInstanceIdsRequest, context: grpc.ServicerContext) -> pb.ListInstanceIdsResponse: """Lists terminal orchestration instance IDs with completion-time pagination.""" with self._lock: - matching = [] + matching: list[OrchestrationInstance] = [] for instance in self._instances.values(): if not self._is_terminal_status(instance.status): continue @@ -478,7 +492,9 @@ def ListInstanceIds(self, request: pb.ListInstanceIdsRequest, context): matching.append(instance) matching.sort(key=lambda i: (i.completed_at, i.instance_id)) - sort_keys = [(i.completed_at, i.instance_id) for i in matching] + sort_keys: list[tuple[datetime, str]] = [ + (cast(datetime, i.completed_at), i.instance_id) for i in matching + ] start_index = 0 if request.HasField("lastInstanceKey") and request.lastInstanceKey.value: @@ -494,7 +510,8 @@ def ListInstanceIds(self, request: pb.ListInstanceIdsRequest, context): next_token = None if start_index + page_size < len(matching) and page: last = page[-1] - encoded = f"{last.completed_at.isoformat()}{_TOKEN_SEP}{last.instance_id}" + last_completed_at = cast(datetime, last.completed_at) + encoded = f"{last_completed_at.isoformat()}{_TOKEN_SEP}{last.instance_id}" next_token = wrappers_pb2.StringValue(value=encoded) return pb.ListInstanceIdsResponse( @@ -503,7 +520,7 @@ def ListInstanceIds(self, request: pb.ListInstanceIdsRequest, context): ) @staticmethod - def _parse_work_item_filters(request: pb.GetWorkItemsRequest): + def _parse_work_item_filters(request: pb.GetWorkItemsRequest) -> tuple[_WorkItemFilter, _WorkItemFilter, _WorkItemFilter]: """Extract filters from the request. Returns a tuple of three values, one per work-item category. Each @@ -517,17 +534,17 @@ def _parse_work_item_filters(request: pb.GetWorkItemsRequest): return None, None, None wf = request.workItemFilters - def _build_filter(filters): + def _build_filter(filters: Iterable[pb.OrchestrationFilter | pb.ActivityFilter]) -> _FilterMap: result: dict[str, frozenset[str]] = {} for f in filters: - versions = frozenset(f.versions) if f.versions else frozenset() - existing = result.get(f.name, frozenset()) + versions = frozenset[str](f.versions) if f.versions else frozenset[str]() + existing = result.get(f.name, frozenset[str]()) result[f.name] = existing | versions return result orch_filter = _build_filter(wf.orchestrations) activity_filter = _build_filter(wf.activities) - entity_filter = {f.name: frozenset() for f in wf.entities} + entity_filter: _FilterMap = {f.name: frozenset[str]() for f in wf.entities} return orch_filter, activity_filter, entity_filter @staticmethod @@ -549,7 +566,7 @@ def _matches_filter(name: str, version: str | None, return True # empty set -- any version return (version or "") in accepted_versions - def GetWorkItems(self, request: pb.GetWorkItemsRequest, context): + def GetWorkItems(self, request: pb.GetWorkItemsRequest, context: grpc.ServicerContext) -> Iterator[pb.WorkItem]: """Streams work items to the worker (orchestration and activity work items).""" self._logger.info("Worker connected and requesting work items") orch_filter, activity_filter, entity_filter = self._parse_work_item_filters(request) @@ -613,8 +630,8 @@ def GetWorkItems(self, request: pb.GetWorkItemsRequest, context): # Check for activity work if not work_item and self._activity_queue: # Scan for the first matching activity - skipped: list = [] - matched_activity = None + skipped: list[ActivityWorkItem] = [] + matched_activity: ActivityWorkItem | None = None while self._activity_queue: candidate = self._activity_queue.popleft() if not self._matches_filter( @@ -706,7 +723,7 @@ def GetWorkItems(self, request: pb.GetWorkItemsRequest, context): except Exception: self._logger.exception("Error in GetWorkItems stream") - def CompleteOrchestratorTask(self, request: pb.OrchestratorResponse, context): + def CompleteOrchestratorTask(self, request: pb.OrchestratorResponse, context: grpc.ServicerContext) -> pb.CompleteTaskResponse: """Completes an orchestration execution with the given actions.""" with self._lock: instance = self._instances.get(request.instanceId) @@ -791,7 +808,7 @@ def CompleteOrchestratorTask(self, request: pb.OrchestratorResponse, context): return pb.CompleteTaskResponse() - def CompleteActivityTask(self, request: pb.ActivityResponse, context): + def CompleteActivityTask(self, request: pb.ActivityResponse, context: grpc.ServicerContext) -> pb.CompleteTaskResponse: """Completes an activity execution.""" with self._lock: instance = self._instances.get(request.instanceId) @@ -826,7 +843,7 @@ def CompleteActivityTask(self, request: pb.ActivityResponse, context): return pb.CompleteTaskResponse() - def CompleteEntityTask(self, request: pb.EntityBatchResult, context): + def CompleteEntityTask(self, request: pb.EntityBatchResult, context: grpc.ServicerContext) -> pb.CompleteTaskResponse: """Completes an entity batch execution.""" with self._lock: # Find entity by completion token @@ -926,7 +943,7 @@ def CompleteEntityTask(self, request: pb.EntityBatchResult, context): return pb.CompleteTaskResponse() - def SignalEntity(self, request: pb.SignalEntityRequest, context): + def SignalEntity(self, request: pb.SignalEntityRequest, context: grpc.ServicerContext) -> pb.SignalEntityResponse: """Signals an entity, queueing an operation for processing.""" with self._lock: entity_id = request.instanceId @@ -956,7 +973,7 @@ def SignalEntity(self, request: pb.SignalEntityRequest, context): self._logger.info(f"Signaled entity '{entity_id}' operation '{request.name}'") return pb.SignalEntityResponse() - def GetEntity(self, request: pb.GetEntityRequest, context): + def GetEntity(self, request: pb.GetEntityRequest, context: grpc.ServicerContext) -> pb.GetEntityResponse: """Gets entity state.""" with self._lock: entity = self._entities.get(request.instanceId) @@ -977,7 +994,7 @@ def GetEntity(self, request: pb.GetEntityRequest, context): return pb.GetEntityResponse(exists=True, entity=metadata) - def QueryInstances(self, request: pb.QueryInstancesRequest, context): + def QueryInstances(self, request: pb.QueryInstancesRequest, context: grpc.ServicerContext) -> pb.QueryInstancesResponse: """Query orchestration instances with filtering support.""" with self._lock: query = request.query @@ -988,7 +1005,7 @@ def QueryInstances(self, request: pb.QueryInstancesRequest, context): except ValueError: start_index = 0 - matching = [] + matching: list[OrchestrationInstance] = [] for instance in self._instances.values(): # Filter by runtime status if query.runtimeStatus and instance.status not in query.runtimeStatus: @@ -1011,7 +1028,7 @@ def QueryInstances(self, request: pb.QueryInstancesRequest, context): page_size = query.maxInstanceCount if query.maxInstanceCount > 0 else len(matching) page = matching[start_index:start_index + page_size] - states = [] + states: list[pb.OrchestrationState] = [] for inst in page: created_ts = timestamp_pb2.Timestamp() created_ts.FromDatetime(inst.created_at) @@ -1044,7 +1061,7 @@ def QueryInstances(self, request: pb.QueryInstancesRequest, context): continuationToken=continuation_token, ) - def QueryEntities(self, request: pb.QueryEntitiesRequest, context): + def QueryEntities(self, request: pb.QueryEntitiesRequest, context: grpc.ServicerContext) -> pb.QueryEntitiesResponse: """Query entities with filtering support.""" with self._lock: query = request.query @@ -1055,7 +1072,7 @@ def QueryEntities(self, request: pb.QueryEntitiesRequest, context): except ValueError: start_index = 0 - matching = [] + matching: list[EntityState] = [] for entity in self._entities.values(): # Filter by instance ID prefix if query.HasField("instanceIdStartsWith") and query.instanceIdStartsWith.value: @@ -1078,7 +1095,7 @@ def QueryEntities(self, request: pb.QueryEntitiesRequest, context): page_size = query.pageSize.value if query.HasField("pageSize") and query.pageSize.value > 0 else len(matching) page = matching[start_index:start_index + page_size] - entities = [] + entities: list[pb.EntityMetadata] = [] for ent in page: last_modified_ts = timestamp_pb2.Timestamp() last_modified_ts.FromDatetime(ent.last_modified_at) @@ -1105,7 +1122,7 @@ def QueryEntities(self, request: pb.QueryEntitiesRequest, context): continuationToken=continuation_token, ) - def CleanEntityStorage(self, request: pb.CleanEntityStorageRequest, context): + def CleanEntityStorage(self, request: pb.CleanEntityStorageRequest, context: grpc.ServicerContext) -> pb.CleanEntityStorageResponse: """Clean entity storage: remove empty entities and release orphaned locks.""" empty_removed = 0 locks_released = 0 @@ -1132,7 +1149,7 @@ def CleanEntityStorage(self, request: pb.CleanEntityStorageRequest, context): orphanedLocksReleased=locks_released, ) - def StreamInstanceHistory(self, request: pb.StreamInstanceHistoryRequest, context): + def StreamInstanceHistory(self, request: pb.StreamInstanceHistoryRequest, context: grpc.ServicerContext) -> Iterator[pb.HistoryChunk]: """Streams orchestration history for an instance.""" with self._lock: instance = self._instances.get(request.instanceId) @@ -1146,23 +1163,23 @@ def StreamInstanceHistory(self, request: pb.StreamInstanceHistoryRequest, contex for offset in range(0, len(history), chunk_size): yield pb.HistoryChunk(events=history[offset:offset + chunk_size]) - def CreateTaskHub(self, request: pb.CreateTaskHubRequest, context): + def CreateTaskHub(self, request: pb.CreateTaskHubRequest, context: grpc.ServicerContext) -> pb.CreateTaskHubResponse: """Creates task hub resources (no-op for in-memory).""" return pb.CreateTaskHubResponse() - def DeleteTaskHub(self, request: pb.DeleteTaskHubRequest, context): + def DeleteTaskHub(self, request: pb.DeleteTaskHubRequest, context: grpc.ServicerContext) -> pb.DeleteTaskHubResponse: """Deletes task hub resources (no-op for in-memory).""" return pb.DeleteTaskHubResponse() - def RewindInstance(self, request: pb.RewindInstanceRequest, context): + def RewindInstance(self, request: pb.RewindInstanceRequest, context: grpc.ServicerContext) -> pb.RewindInstanceResponse: """Rewinds an orchestration instance (not implemented).""" context.abort(grpc.StatusCode.UNIMPLEMENTED, "RewindInstance not implemented") - def AbandonTaskActivityWorkItem(self, request: pb.AbandonActivityTaskRequest, context): + def AbandonTaskActivityWorkItem(self, request: pb.AbandonActivityTaskRequest, context: grpc.ServicerContext) -> pb.AbandonActivityTaskResponse: """Abandons an activity work item.""" return pb.AbandonActivityTaskResponse() - def AbandonTaskOrchestratorWorkItem(self, request: pb.AbandonOrchestrationTaskRequest, context): + def AbandonTaskOrchestratorWorkItem(self, request: pb.AbandonOrchestrationTaskRequest, context: grpc.ServicerContext) -> pb.AbandonOrchestrationTaskResponse: """Abandons an orchestration work item, restoring it for re-processing.""" with self._lock: for instance_id in list(self._orchestration_in_flight): @@ -1179,7 +1196,7 @@ def AbandonTaskOrchestratorWorkItem(self, request: pb.AbandonOrchestrationTaskRe break return pb.AbandonOrchestrationTaskResponse() - def AbandonTaskEntityWorkItem(self, request: pb.AbandonEntityTaskRequest, context): + def AbandonTaskEntityWorkItem(self, request: pb.AbandonEntityTaskRequest, context: grpc.ServicerContext) -> pb.AbandonEntityTaskResponse: """Abandons an entity work item, restoring it for re-processing.""" with self._lock: for entity in self._entities.values(): @@ -1422,7 +1439,7 @@ def _process_complete_orchestration_action(self, instance: OrchestrationInstance @staticmethod def _clone_history_event(event: pb.HistoryEvent) -> pb.HistoryEvent: - cloned_event = pb.HistoryEvent() + cloned_event: pb.HistoryEvent = pb.HistoryEvent() cloned_event.CopyFrom(event) return cloned_event @@ -1474,7 +1491,7 @@ def _process_create_timer_action(self, instance: OrchestrationInstance, now = datetime.now(timezone.utc) delay = max(0, (fire_at - now).total_seconds()) - def fire_timer(): + def fire_timer() -> None: time.sleep(delay) with self._lock: current_instance = self._instances.get(instance.instance_id) @@ -1522,7 +1539,7 @@ def _process_create_sub_orchestration_action(self, instance: OrchestrationInstan def _watch_sub_orchestration(self, parent_instance_id: str, sub_instance_id: str, task_id: int): """Watches a sub-orchestration for completion and delivers the result to the parent.""" - def watch(): + def watch() -> None: # Wait for sub-orchestration to complete sub_instance = self._wait_for_state( sub_instance_id, @@ -1708,7 +1725,7 @@ def _try_grant_lock(self, pending: PendingLockRequest) -> bool: def _try_grant_pending_locks(self): """Attempts to grant any pending lock requests that can now be fulfilled.""" - still_pending = [] + still_pending: list[PendingLockRequest] = [] for pending in self._pending_lock_requests: if self._can_grant_lock(pending): self._grant_lock(pending) diff --git a/durabletask/worker.py b/durabletask/worker.py index a7b42f2..5e98cc1 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -7,13 +7,14 @@ import logging import os import time +from collections.abc import Callable, Generator, Sequence from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from datetime import datetime, timedelta, timezone from threading import Event, Lock, Thread from types import GeneratorType from enum import Enum -from typing import Any, Generator, Sequence, TypeVar, overload +from typing import Any, TypeVar, cast, overload import uuid from packaging.version import InvalidVersion, parse @@ -31,7 +32,6 @@ from durabletask.entities import DurableEntity, EntityLock, EntityInstanceId, EntityContext from durabletask.internal.json_encode_output_exception import JsonEncodeOutputException from durabletask.internal.orchestration_entity_context import OrchestrationEntityContext -from durabletask.internal.proto_task_hub_sidecar_service_stub import ProtoTaskHubSidecarServiceStub import durabletask.internal.helpers as ph import durabletask.internal.exceptions as pe import durabletask.internal.orchestrator_service_pb2 as pb @@ -53,6 +53,7 @@ DATETIME_STRING_FORMAT = '%Y-%m-%dT%H:%M:%S.%fZ' DEFAULT_MAXIMUM_TIMER_INTERVAL = timedelta(days=3) _STREAM_CLOSED_SENTINEL = object() +_WorkItem = tuple[Callable[..., Any], Callable[..., Any], tuple[Any, ...], dict[str, Any]] class ConcurrencyOptions: @@ -241,7 +242,7 @@ class OrchestrationWorkItemFilter: name: str """The name of the orchestration to filter.""" - versions: list[str] = field(default_factory=list) + versions: list[str] = field(default_factory=list[str]) """Optional list of versions to filter.""" @@ -251,7 +252,7 @@ class ActivityWorkItemFilter: name: str """The name of the activity to filter.""" - versions: list[str] = field(default_factory=list) + versions: list[str] = field(default_factory=list[str]) """Optional list of versions to filter.""" @@ -283,11 +284,11 @@ class WorkItemFilters: :meth:`TaskHubGrpcWorker.use_work_item_filters` to enable filtering. """ - orchestrations: list[OrchestrationWorkItemFilter] = field(default_factory=list) + orchestrations: list[OrchestrationWorkItemFilter] = field(default_factory=list[OrchestrationWorkItemFilter]) """List of orchestration filters.""" - activities: list[ActivityWorkItemFilter] = field(default_factory=list) + activities: list[ActivityWorkItemFilter] = field(default_factory=list[ActivityWorkItemFilter]) """List of activity filters.""" - entities: list[EntityWorkItemFilter] = field(default_factory=list) + entities: list[EntityWorkItemFilter] = field(default_factory=list[EntityWorkItemFilter]) """List of entity filters.""" @classmethod @@ -296,7 +297,7 @@ def _from_registry(cls, registry: '_Registry') -> 'WorkItemFilters': versions: list[str] = [] v = registry.versioning if v and v.match_strategy == VersionMatchStrategy.STRICT and v.version: - versions = [registry.versioning.version] + versions = [v.version] orchestrations = [ OrchestrationWorkItemFilter(name=name, versions=list(versions)) @@ -335,9 +336,9 @@ def _to_grpc(self) -> pb.WorkItemFilters: class _Registry: - orchestrators: dict[str, task.Orchestrator] - activities: dict[str, task.Activity] - entities: dict[str, task.Entity] + orchestrators: dict[str, task.Orchestrator[Any, Any]] + activities: dict[str, task.Activity[Any, Any]] + entities: dict[str, task.Entity[Any, Any]] versioning: VersioningOptions | None = None def __init__(self): @@ -383,7 +384,7 @@ def add_named_activity(self, name: str, fn: task.Activity[TInput, TOutput]) -> N def get_activity(self, name: str) -> task.Activity[Any, Any] | None: return self.activities.get(name) - def add_entity(self, fn: task.Entity, name: str | None = None) -> str: + def add_entity(self, fn: task.Entity[Any, Any], name: str | None = None) -> str: if fn is None: raise ValueError("An entity function argument is required.") @@ -393,7 +394,7 @@ def add_entity(self, fn: task.Entity, name: str | None = None) -> str: self.add_named_entity(name, fn) return name - def add_named_entity(self, name: str, fn: task.Entity) -> None: + def add_named_entity(self, name: str, fn: task.Entity[Any, Any]) -> None: name = name.lower() EntityInstanceId.validate_entity_name(name) if name in self.entities: @@ -401,7 +402,7 @@ def add_named_entity(self, name: str, fn: task.Entity) -> None: self.entities[name] = fn - def get_entity(self, name: str) -> task.Entity | None: + def get_entity(self, name: str) -> task.Entity[Any, Any] | None: return self.entities.get(name) @@ -570,6 +571,7 @@ def __init__( self._maximum_timer_interval = maximum_timer_interval self._work_item_filters: WorkItemFilters | None = None self._auto_generate_work_item_filters: bool = False + self._runLoop: Thread | None = None @property def concurrency_options(self) -> ConcurrencyOptions: @@ -613,7 +615,7 @@ def add_orchestrator(self, fn: task.Orchestrator[TInput, TOutput]) -> str: ) return self._registry.add_orchestrator(fn) - def add_activity(self, fn: task.Activity) -> str: + def add_activity(self, fn: task.Activity[Any, Any]) -> str: """Registers an activity function with the worker.""" if self._is_running: raise RuntimeError( @@ -621,7 +623,7 @@ def add_activity(self, fn: task.Activity) -> str: ) return self._registry.add_activity(fn) - def add_entity(self, fn: task.Entity, name: str | None = None) -> str: + def add_entity(self, fn: task.Entity[Any, Any], name: str | None = None) -> str: """Registers an entity function with the worker.""" if self._is_running: raise RuntimeError( @@ -692,11 +694,11 @@ def start(self): # Auto-generate work item filters from registry if opted in if self._auto_generate_work_item_filters: - self._work_item_filters = WorkItemFilters._from_registry(self._registry) + self._work_item_filters = WorkItemFilters._from_registry(self._registry) # pyright: ignore[reportPrivateUsage] self._shutdown.clear() - def run_loop(): + def run_loop() -> None: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) loop.run_until_complete(self._async_run_loop()) @@ -706,12 +708,12 @@ def run_loop(): self._runLoop.start() self._is_running = True - async def _async_run_loop(self): + async def _async_run_loop(self) -> None: self._async_worker_manager.prepare_for_run() worker_task = asyncio.create_task(self._async_worker_manager.run()) - current_channel = self._channel - current_stub = None - current_reader_thread = None + current_channel: grpc.Channel | None = self._channel + current_stub: Any | None = None + current_reader_thread: Thread | None = None conn_retry_count = 0 failure_tracker = FailureTracker( threshold=self._resiliency_options.channel_recreate_failure_threshold, @@ -725,7 +727,7 @@ def get_reconnect_delay_seconds() -> float: cap_seconds=self._resiliency_options.reconnect_backoff_cap_seconds, ) - def create_fresh_connection(): + def create_fresh_connection() -> None: nonlocal current_channel, current_stub, conn_retry_count current_stub = None try: @@ -746,8 +748,11 @@ def create_fresh_connection(): current_stub = None raise - def wrap_with_release(handler, release): - def wrapped(*args, **kwargs): + def wrap_with_release( + handler: Callable[..., Any], + release: Callable[[], None], + ) -> Callable[..., Any]: + def wrapped(*args: Any, **kwargs: Any) -> Any: try: return handler(*args, **kwargs) finally: @@ -756,14 +761,14 @@ def wrapped(*args, **kwargs): return wrapped def submit_work_item( - submit_func, - handler, - cancellation_handler, - request, - stub, - completion_token, - channel, - ): + submit_func: Callable[..., None], + handler: Callable[..., Any], + cancellation_handler: Callable[..., Any], + request: Any, + stub: Any, + completion_token: Any, + channel: grpc.Channel, + ) -> None: release = in_flight_channel_tracker.acquire(channel) try: submit_func( @@ -810,7 +815,7 @@ def invalidate_connection( current_channel = None current_stub = None - def should_invalidate_connection(rpc_error): + def should_invalidate_connection(rpc_error: grpc.RpcError) -> bool: error_code = rpc_error.code() # type: ignore connection_level_errors = { grpc.StatusCode.UNAVAILABLE, @@ -848,7 +853,7 @@ def should_invalidate_connection(rpc_error): assert current_channel is not None stub = current_stub channel = current_channel - capabilities = [] + capabilities: list[Any] = [] if self._payload_store is not None: capabilities.append(pb.WORKER_CAPABILITY_LARGE_PAYLOADS) get_work_items_request = pb.GetWorkItemsRequest( @@ -858,7 +863,7 @@ def should_invalidate_connection(rpc_error): ) if self._work_item_filters is not None: get_work_items_request.workItemFilters.CopyFrom( - self._work_item_filters._to_grpc() + self._work_item_filters._to_grpc() # pyright: ignore[reportPrivateUsage] ) self._response_stream = stub.GetWorkItems(get_work_items_request) self._logger.info( @@ -868,10 +873,10 @@ def should_invalidate_connection(rpc_error): # Use a thread to read from the blocking gRPC stream and forward to asyncio import queue - work_item_queue = queue.Queue() + work_item_queue: queue.Queue[Any] = queue.Queue() saw_message = False - def stream_reader(): + def stream_reader() -> None: try: response_stream = self._response_stream if response_stream is None: @@ -1071,9 +1076,9 @@ def stop(self): def _execute_orchestrator( self, req: pb.OrchestratorRequest, - stub: stubs.TaskHubSidecarServiceStub | ProtoTaskHubSidecarServiceStub, - completionToken, - ): + stub: Any, + completionToken: Any, + ) -> None: instance_id = req.instanceId # De-externalize any large-payload tokens in the incoming request @@ -1130,14 +1135,14 @@ def _execute_orchestrator( is_failed, failure_details=failure_details, parent_trace_context=parent_trace_ctx, - orchestration_trace_context=result._orchestration_trace_context, + orchestration_trace_context=result._orchestration_trace_context, # pyright: ignore[reportPrivateUsage] ) # Include the span ID in the orchestration trace context # so it persists across dispatches. orch_span_id = None - if result._orchestration_trace_context: - orch_span_id = result._orchestration_trace_context.spanID + if result._orchestration_trace_context: # pyright: ignore[reportPrivateUsage] + orch_span_id = result._orchestration_trace_context.spanID # pyright: ignore[reportPrivateUsage] orch_trace_ctx = tracing.build_orchestration_trace_context( start_time_ns, span_id=orch_span_id) @@ -1202,9 +1207,9 @@ def _execute_orchestrator( def _cancel_orchestrator( self, req: pb.OrchestratorRequest, - stub: stubs.TaskHubSidecarServiceStub | ProtoTaskHubSidecarServiceStub, - completionToken, - ): + stub: Any, + completionToken: Any, + ) -> None: stub.AbandonTaskOrchestratorWorkItem( pb.AbandonOrchestrationTaskRequest( completionToken=completionToken @@ -1215,9 +1220,9 @@ def _cancel_orchestrator( def _execute_activity( self, req: pb.ActivityRequest, - stub: stubs.TaskHubSidecarServiceStub | ProtoTaskHubSidecarServiceStub, - completionToken, - ): + stub: Any, + completionToken: Any, + ) -> None: instance_id = req.orchestrationInstance.instanceId # De-externalize any large-payload tokens in the incoming request @@ -1272,9 +1277,9 @@ def _execute_activity( def _cancel_activity( self, req: pb.ActivityRequest, - stub: stubs.TaskHubSidecarServiceStub | ProtoTaskHubSidecarServiceStub, - completionToken, - ): + stub: Any, + completionToken: Any, + ) -> None: stub.AbandonTaskActivityWorkItem( pb.AbandonActivityTaskRequest( completionToken=completionToken @@ -1285,9 +1290,9 @@ def _cancel_activity( def _execute_entity_batch( self, req: pb.EntityBatchRequest | pb.EntityRequest, - stub: stubs.TaskHubSidecarServiceStub | ProtoTaskHubSidecarServiceStub, - completionToken, - ): + stub: Any, + completionToken: Any, + ) -> pb.EntityBatchResult: operation_infos: list[pb.OperationInfo] = [] if isinstance(req, pb.EntityRequest): req, operation_infos = helpers.convert_to_entity_batch_request(req) @@ -1354,7 +1359,7 @@ def _execute_entity_batch( batch_result = pb.EntityBatchResult( results=results, actions=entity_state.get_operation_actions(), - entityState=helpers.get_string_value(shared.to_json(entity_state._current_state)) if entity_state._current_state else None, + entityState=helpers.get_string_value(shared.to_json(entity_state._current_state)) if entity_state._current_state else None, # pyright: ignore[reportPrivateUsage] failureDetails=None, completionToken=completionToken, operationInfos=operation_infos, @@ -1379,9 +1384,9 @@ def _execute_entity_batch( def _cancel_entity_batch( self, req: pb.EntityBatchRequest | pb.EntityRequest, - stub: stubs.TaskHubSidecarServiceStub | ProtoTaskHubSidecarServiceStub, - completionToken, - ): + stub: Any, + completionToken: Any, + ) -> None: stub.AbandonTaskEntityWorkItem( pb.AbandonEntityTaskRequest( completionToken=completionToken @@ -1391,8 +1396,8 @@ def _cancel_entity_batch( class _RuntimeOrchestrationContext(task.OrchestrationContext): - _generator: Generator[task.Task, Any, Any] | None - _previous_task: task.Task | None + _generator: Generator[task.Task[Any], Any, Any] | None + _previous_task: task.Task[Any] | None def __init__(self, instance_id: str, @@ -1404,7 +1409,7 @@ def __init__(self, self._is_complete = False self._result = None self._pending_actions: dict[int, pb.OrchestratorAction] = {} - self._pending_tasks: dict[int, task.CompletableTask] = {} + self._pending_tasks: dict[int, task.CompletableTask[Any]] = {} # Maps entity ID to task ID self._entity_task_id_map: dict[str, tuple[EntityInstanceId, str, int]] = {} self._entity_lock_task_id_map: dict[str, tuple[EntityInstanceId, int]] = {} @@ -1419,7 +1424,7 @@ def __init__(self, self._version: str | None = None self._completion_status: pb.OrchestrationStatus | None = None self._received_events: dict[str, list[Any]] = {} - self._pending_events: dict[str, list[task.CancellableTask]] = {} + self._pending_events: dict[str, list[task.CancellableTask[Any]]] = {} self._new_input: Any | None = None self._save_events = False self._encoded_custom_status: str | None = None @@ -1427,14 +1432,14 @@ def __init__(self, self._orchestration_trace_context: pb.TraceContext | None = None self._maximum_timer_interval = maximum_timer_interval - def run(self, generator: Generator[task.Task, Any, Any]): + def run(self, generator: Generator[task.Task[Any], Any, Any]) -> None: self._generator = generator # TODO: Do something with this task task = next(generator) # this starts the generator # TODO: Check if the task is null? self._previous_task = task - def resume(self): + def resume(self) -> None: if self._generator is None: # This is never expected unless maybe there's an issue with the history raise TypeError( @@ -1446,7 +1451,7 @@ def resume(self): # case is if the user yielded on a WhenAll task and there are still # outstanding child tasks that need to be completed. while self._previous_task is not None and self._previous_task.is_complete: - next_task = None + next_task: Any = None if self._previous_task.is_failed: # Raise the failure as an exception to the generator. # The orchestrator can then either handle the exception or allow it to fail the orchestration. @@ -1591,13 +1596,13 @@ def set_custom_status(self, custom_status: Any) -> None: shared.to_json(custom_status) if custom_status is not None else None ) - def create_timer(self, fire_at: datetime | timedelta) -> task.CancellableTask: + def create_timer(self, fire_at: datetime | timedelta) -> task.TimerTask: return self.create_timer_internal(fire_at) def create_timer_internal( self, fire_at: datetime | timedelta, - retryable_task: task.RetryableTask | None = None, + retryable_task: task.RetryableTask[Any] | None = None, ) -> task.TimerTask: id = self.next_sequence_number() if isinstance(fire_at, timedelta): @@ -1613,7 +1618,7 @@ def create_timer_internal( and self.current_utc_datetime + self._maximum_timer_interval < final_fire_at ): timer_task = task.TimerTask(final_fire_at, self._maximum_timer_interval) - next_fire_at = timer_task._get_next_fire_at(self.current_utc_datetime) + next_fire_at = timer_task._get_next_fire_at(self.current_utc_datetime) # pyright: ignore[reportPrivateUsage] else: timer_task = task.TimerTask() @@ -1643,13 +1648,13 @@ def call_activity( self.call_activity_function_helper( id, activity, input=input, retry_policy=retry_policy, is_sub_orch=False, tags=tags ) - return self._pending_tasks.get(id, task.CompletableTask()) + return cast(task.CompletableTask[TOutput], self._pending_tasks.get(id, task.CompletableTask[TOutput]())) def call_entity( self, entity: EntityInstanceId, operation: str, - input: TInput | None = None, + input: Any = None, ) -> task.CompletableTask[Any]: id = self.next_sequence_number() @@ -1657,13 +1662,13 @@ def call_entity( id, entity, operation, input=input ) - return self._pending_tasks.get(id, task.CompletableTask()) + return self._pending_tasks.get(id, task.CompletableTask[Any]()) def signal_entity( self, entity_id: EntityInstanceId, operation_name: str, - input: TInput | None = None + input: Any = None ) -> None: id = self.next_sequence_number() @@ -1677,7 +1682,7 @@ def lock_entities(self, entities: list[EntityInstanceId]) -> task.CompletableTas self.lock_entities_function_helper( id, entities ) - return self._pending_tasks.get(id, task.CompletableTask()) + return cast(task.CompletableTask[EntityLock], self._pending_tasks.get(id, task.CompletableTask[EntityLock]())) def call_sub_orchestrator( self, @@ -1704,7 +1709,7 @@ def call_sub_orchestrator( instance_id=instance_id, version=orchestrator_version ) - return self._pending_tasks.get(id, task.CompletableTask()) + return cast(task.CompletableTask[TOutput], self._pending_tasks.get(id, task.CompletableTask[TOutput]())) def call_activity_function_helper( self, @@ -1787,8 +1792,8 @@ def call_entity_function_helper( entity_id: EntityInstanceId, operation: str, *, - input: TInput | None = None, - ): + input: Any = None, + ) -> None: if id is None: id = self.next_sequence_number() @@ -1800,7 +1805,7 @@ def call_entity_function_helper( action = ph.new_call_entity_action(id, self.instance_id, entity_id, operation, encoded_input, self.new_uuid()) self._pending_actions[id] = action - fn_task = task.CompletableTask() + fn_task = task.CompletableTask[Any]() self._pending_tasks[id] = fn_task def signal_entity_function_helper( @@ -1808,7 +1813,7 @@ def signal_entity_function_helper( id: int | None, entity_id: EntityInstanceId, operation: str, - input: TInput | None + input: Any = None ) -> None: if id is None: id = self.next_sequence_number() @@ -1823,7 +1828,7 @@ def signal_entity_function_helper( action = ph.new_signal_entity_action(id, entity_id, operation, encoded_input, self.new_uuid()) self._pending_actions[id] = action - def lock_entities_function_helper(self, id: int, entities: list[EntityInstanceId]) -> None: + def lock_entities_function_helper(self, id: int | None, entities: list[EntityInstanceId]) -> None: if id is None: id = self.next_sequence_number() @@ -1855,13 +1860,13 @@ def _exit_critical_section(self) -> None: action = pb.OrchestratorAction(id=task_id, sendEntityMessage=entity_unlock_message) self._pending_actions[task_id] = action - def wait_for_external_event(self, name: str) -> task.CancellableTask: + def wait_for_external_event(self, name: str) -> task.CancellableTask[Any]: # Check to see if this event has already been received, in which case we # can return it immediately. Otherwise, record out intent to receive an # event with the given name so that we can resume the generator when it # arrives. If there are multiple events with the same name, we return # them in the order they were received. - external_event_task: task.CancellableTask = task.CancellableTask() + external_event_task: task.CancellableTask[Any] = task.CancellableTask() event_name = name.casefold() event_list = self._received_events.get(event_name, None) if event_list: @@ -1890,7 +1895,7 @@ def _cancel_wait() -> None: external_event_task.set_cancel_handler(_cancel_wait) return external_event_task - def continue_as_new(self, new_input, *, save_events: bool = False) -> None: + def continue_as_new(self, new_input: Any, *, save_events: bool = False) -> None: if self._is_complete: return @@ -1923,7 +1928,7 @@ def __init__( class _OrchestrationExecutor: - _generator: task.Orchestrator | None = None + _generator: task.Orchestrator[Any, Any] | None = None def __init__( self, @@ -1978,7 +1983,7 @@ def execute( self._logger.debug( f"{instance_id}: Rebuilding local state with {len(old_events)} history event..." ) - ctx._is_replaying = True + ctx._is_replaying = True # pyright: ignore[reportPrivateUsage] for old_event in old_events: self.process_event(ctx, old_event) @@ -1988,7 +1993,7 @@ def execute( self._logger.debug( f"{instance_id}: Processing {len(new_events)} new event(s): {summary}" ) - ctx._is_replaying = False + ctx._is_replaying = False # pyright: ignore[reportPrivateUsage] for new_event in new_events: self.process_event(ctx, new_event) @@ -2006,18 +2011,18 @@ def execute( self._logger.debug(f"{instance_id}: Orchestration {orchestration_name} failed") ctx.set_failed(ex) - if not ctx._is_complete: - task_count = len(ctx._pending_tasks) - event_count = len(ctx._pending_events) + if not ctx._is_complete: # pyright: ignore[reportPrivateUsage] + task_count = len(ctx._pending_tasks) # pyright: ignore[reportPrivateUsage] + event_count = len(ctx._pending_events) # pyright: ignore[reportPrivateUsage] self._logger.info( f"{instance_id}: Orchestrator {orchestration_name} yielded with {task_count} task(s) " f"and {event_count} event(s) outstanding." ) elif ( - ctx._completion_status and ctx._completion_status is not pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW + ctx._completion_status and ctx._completion_status is not pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW # pyright: ignore[reportPrivateUsage] ): completion_status_str = ph.get_orchestration_status_str( - ctx._completion_status + ctx._completion_status # pyright: ignore[reportPrivateUsage] ) self._logger.info( f"{instance_id}: Orchestration {orchestration_name} completed with status: {completion_status_str}" @@ -2029,8 +2034,8 @@ def execute( f"{instance_id}: Returning {len(actions)} action(s): {_get_action_summary(actions)}" ) return ExecutionResults( - actions=actions, encoded_custom_status=ctx._encoded_custom_status, - orchestration_trace_context=ctx._orchestration_trace_context, + actions=actions, encoded_custom_status=ctx._encoded_custom_status, # pyright: ignore[reportPrivateUsage] + orchestration_trace_context=ctx._orchestration_trace_context, # pyright: ignore[reportPrivateUsage] ) def process_event( @@ -2052,22 +2057,22 @@ def process_event( ) if event.executionStarted.version: - ctx._version = event.executionStarted.version.value + ctx._version = event.executionStarted.version.value # pyright: ignore[reportPrivateUsage] # Store the parent trace context for propagation to child tasks if event.executionStarted.HasField("parentTraceContext"): - ctx._parent_trace_context = event.executionStarted.parentTraceContext + ctx._parent_trace_context = event.executionStarted.parentTraceContext # pyright: ignore[reportPrivateUsage] # Reuse a persisted span ID from a prior dispatch so # activities/timers/sub-orchestrations across all # dispatches share the same parent. On the first # dispatch, generate a new random span ID. if self._persisted_orch_span_id: - ctx._orchestration_trace_context = tracing.reconstruct_trace_context( - ctx._parent_trace_context, + ctx._orchestration_trace_context = tracing.reconstruct_trace_context( # pyright: ignore[reportPrivateUsage] + ctx._parent_trace_context, # pyright: ignore[reportPrivateUsage] self._persisted_orch_span_id) else: - ctx._orchestration_trace_context = tracing.generate_client_trace_context( - parent_trace_context=ctx._parent_trace_context) + ctx._orchestration_trace_context = tracing.generate_client_trace_context( # pyright: ignore[reportPrivateUsage] + parent_trace_context=ctx._parent_trace_context) # pyright: ignore[reportPrivateUsage] if self._registry.versioning: version_failure = self.evaluate_orchestration_versioning( @@ -2085,7 +2090,7 @@ def process_event( # deserialize the input, if any input = None if ( - event.executionStarted.input is not None and event.executionStarted.input.value != "" + event.executionStarted.HasField("input") and event.executionStarted.input.value != "" ): input = shared.from_json(event.executionStarted.input.value) @@ -2094,7 +2099,7 @@ def process_event( ) # this does not execute the generator, only creates it if isinstance(result, GeneratorType): # Start the orchestrator's generator function - ctx.run(result) + ctx.run(cast(Generator[task.Task[Any], Any, Any], result)) else: # This is an orchestrator that doesn't schedule any tasks ctx.set_complete(result, pb.ORCHESTRATION_STATUS_COMPLETED) @@ -2102,7 +2107,7 @@ def process_event( # This history event confirms that the timer was successfully scheduled. # Remove the timerCreated event from the pending action list so we don't schedule it again. timer_id = event.eventId - action = ctx._pending_actions.pop(timer_id, None) + action = ctx._pending_actions.pop(timer_id, None) # pyright: ignore[reportPrivateUsage] if not action: raise _get_non_determinism_error( timer_id, task.get_name(ctx.create_timer) @@ -2121,7 +2126,7 @@ def process_event( ) elif event.HasField("timerFired"): timer_id = event.timerFired.timerId - timer_task = ctx._pending_tasks.pop(timer_id, None) + timer_task = ctx._pending_tasks.pop(timer_id, None) # pyright: ignore[reportPrivateUsage] if not timer_task: # Unexpected event for unknown timer; log and skip. if not ctx.is_replaying: @@ -2130,7 +2135,7 @@ def process_event( ) return if not isinstance(timer_task, task.TimerTask): - if not ctx._is_replaying: + if not ctx._is_replaying: # pyright: ignore[reportPrivateUsage] self._logger.warning( f"{ctx.instance_id}: Ignoring timerFired event with non-timer task ID = {timer_id}." ) @@ -2144,25 +2149,25 @@ def process_event( self._orchestration_name, ctx.instance_id, timer_id, fire_at, scheduled_time_ns=created_ns, - parent_trace_context=ctx._orchestration_trace_context or ctx._parent_trace_context, + parent_trace_context=ctx._orchestration_trace_context or ctx._parent_trace_context, # pyright: ignore[reportPrivateUsage] ) - next_fire_at = timer_task._handle_timer_fired(event.timerFired.fireAt.ToDatetime()) + next_fire_at = timer_task._handle_timer_fired(event.timerFired.fireAt.ToDatetime()) # pyright: ignore[reportPrivateUsage] if next_fire_at is not None: id = ctx.next_sequence_number() new_action = ph.new_create_timer_action(id, next_fire_at) - ctx._pending_tasks[id] = timer_task - ctx._pending_actions[id] = new_action + ctx._pending_tasks[id] = timer_task # pyright: ignore[reportPrivateUsage] + ctx._pending_actions[id] = new_action # pyright: ignore[reportPrivateUsage] def _cancel_timer() -> None: - ctx._pending_actions.pop(id, None) - ctx._pending_tasks.pop(id, None) + ctx._pending_actions.pop(id, None) # pyright: ignore[reportPrivateUsage] + ctx._pending_tasks.pop(id, None) # pyright: ignore[reportPrivateUsage] timer_task.set_cancel_handler(_cancel_timer) else: - if timer_task._retryable_parent is not None: - activity_action = timer_task._retryable_parent._action + if timer_task._retryable_parent is not None: # pyright: ignore[reportPrivateUsage] + activity_action = timer_task._retryable_parent._action # pyright: ignore[reportPrivateUsage] - if not timer_task._retryable_parent._is_sub_orch: + if not timer_task._retryable_parent._is_sub_orch: # pyright: ignore[reportPrivateUsage] cur_task = activity_action.scheduleTask instance_id = None else: @@ -2172,10 +2177,10 @@ def _cancel_timer() -> None: id=activity_action.id, activity_function=cur_task.name, input=cur_task.input.value, - retry_policy=timer_task._retryable_parent._retry_policy, - is_sub_orch=timer_task._retryable_parent._is_sub_orch, + retry_policy=timer_task._retryable_parent._retry_policy, # pyright: ignore[reportPrivateUsage] + is_sub_orch=timer_task._retryable_parent._is_sub_orch, # pyright: ignore[reportPrivateUsage] instance_id=instance_id, - fn_task=timer_task._retryable_parent, + fn_task=timer_task._retryable_parent, # pyright: ignore[reportPrivateUsage] ) else: ctx.resume() @@ -2183,8 +2188,8 @@ def _cancel_timer() -> None: # This history event confirms that the activity execution was successfully scheduled. # Remove the taskScheduled event from the pending action list so we don't schedule it again. task_id = event.eventId - action = ctx._pending_actions.pop(task_id, None) - activity_task = ctx._pending_tasks.get(task_id, None) + action = ctx._pending_actions.pop(task_id, None) # pyright: ignore[reportPrivateUsage] + activity_task = ctx._pending_tasks.get(task_id, None) # pyright: ignore[reportPrivateUsage] if not action: raise _get_non_determinism_error( task_id, task.get_name(ctx.call_activity) @@ -2213,7 +2218,7 @@ def _cancel_timer() -> None: elif event.HasField("taskCompleted"): # This history event contains the result of a completed activity task. task_id = event.taskCompleted.taskScheduledId - activity_task = ctx._pending_tasks.pop(task_id, None) + activity_task = ctx._pending_tasks.pop(task_id, None) # pyright: ignore[reportPrivateUsage] if not activity_task: # Unexpected completion for unknown task; log and skip. if not ctx.is_replaying: @@ -2230,7 +2235,7 @@ def _cancel_timer() -> None: tracing.emit_client_span( t_type, t_name, t_iid, task_id, client_trace_context=c_ctx, - parent_trace_context=ctx._orchestration_trace_context or ctx._parent_trace_context, + parent_trace_context=ctx._orchestration_trace_context or ctx._parent_trace_context, # pyright: ignore[reportPrivateUsage] start_time_ns=s_ns, end_time_ns=e_ns, version=t_ver, ) @@ -2241,7 +2246,7 @@ def _cancel_timer() -> None: ctx.resume() elif event.HasField("taskFailed"): task_id = event.taskFailed.taskScheduledId - activity_task = ctx._pending_tasks.pop(task_id, None) + activity_task = ctx._pending_tasks.pop(task_id, None) # pyright: ignore[reportPrivateUsage] if not activity_task: # Unexpected failure for unknown task; log and skip. if not ctx.is_replaying: @@ -2259,27 +2264,29 @@ def _cancel_timer() -> None: tracing.emit_client_span( t_type, t_name, t_iid, task_id, client_trace_context=c_ctx, - parent_trace_context=ctx._orchestration_trace_context or ctx._parent_trace_context, + parent_trace_context=ctx._orchestration_trace_context or ctx._parent_trace_context, # pyright: ignore[reportPrivateUsage] start_time_ns=s_ns, end_time_ns=e_ns, is_error=True, error_message=str(event.taskFailed.failureDetails.errorMessage), version=t_ver, ) - if isinstance(activity_task, task.RetryableTask): - if activity_task._retry_policy is not None: - next_delay = activity_task.compute_next_delay() - if next_delay is None: - activity_task.fail( - f"{ctx.instance_id}: Activity task #{task_id} failed: {event.taskFailed.failureDetails.errorMessage}", - event.taskFailed.failureDetails, - ) - ctx.resume() - else: - activity_task.increment_attempt_count() - ctx.create_timer_internal(next_delay, activity_task) - elif isinstance(activity_task, task.CompletableTask): - activity_task.fail( + activity_task_obj = cast(object, activity_task) + if isinstance(activity_task_obj, task.RetryableTask): + retryable_activity_task = cast(task.RetryableTask[Any], activity_task_obj) + next_delay = retryable_activity_task.compute_next_delay() + if next_delay is None: + retryable_activity_task.fail( + f"{ctx.instance_id}: Activity task #{task_id} failed: {event.taskFailed.failureDetails.errorMessage}", + event.taskFailed.failureDetails, + ) + ctx.resume() + else: + retryable_activity_task.increment_attempt_count() + ctx.create_timer_internal(next_delay, retryable_activity_task) + elif isinstance(activity_task_obj, task.CompletableTask): + completable_activity_task = cast(task.CompletableTask[Any], activity_task_obj) + completable_activity_task.fail( f"{ctx.instance_id}: Activity task #{task_id} failed: {event.taskFailed.failureDetails.errorMessage}", event.taskFailed.failureDetails, ) @@ -2290,7 +2297,7 @@ def _cancel_timer() -> None: # This history event confirms that the sub-orchestration execution was successfully scheduled. # Remove the subOrchestrationInstanceCreated event from the pending action list so we don't schedule it again. task_id = event.eventId - action = ctx._pending_actions.pop(task_id, None) + action = ctx._pending_actions.pop(task_id, None) # pyright: ignore[reportPrivateUsage] if not action: raise _get_non_determinism_error( task_id, task.get_name(ctx.call_sub_orchestrator) @@ -2320,7 +2327,7 @@ def _cancel_timer() -> None: ) elif event.HasField("subOrchestrationInstanceCompleted"): task_id = event.subOrchestrationInstanceCompleted.taskScheduledId - sub_orch_task = ctx._pending_tasks.pop(task_id, None) + sub_orch_task = ctx._pending_tasks.pop(task_id, None) # pyright: ignore[reportPrivateUsage] if not sub_orch_task: # Unexpected completion for unknown sub-orchestration; log and skip. if not ctx.is_replaying: @@ -2337,7 +2344,7 @@ def _cancel_timer() -> None: tracing.emit_client_span( t_type, t_name, t_iid, task_id, client_trace_context=c_ctx, - parent_trace_context=ctx._orchestration_trace_context or ctx._parent_trace_context, + parent_trace_context=ctx._orchestration_trace_context or ctx._parent_trace_context, # pyright: ignore[reportPrivateUsage] start_time_ns=s_ns, end_time_ns=e_ns, version=t_ver, ) @@ -2351,7 +2358,7 @@ def _cancel_timer() -> None: elif event.HasField("subOrchestrationInstanceFailed"): failedEvent = event.subOrchestrationInstanceFailed task_id = failedEvent.taskScheduledId - sub_orch_task = ctx._pending_tasks.pop(task_id, None) + sub_orch_task = ctx._pending_tasks.pop(task_id, None) # pyright: ignore[reportPrivateUsage] if not sub_orch_task: # Unexpected failure for unknown sub-orchestration; log and skip. if not ctx.is_replaying: @@ -2368,26 +2375,28 @@ def _cancel_timer() -> None: tracing.emit_client_span( t_type, t_name, t_iid, task_id, client_trace_context=c_ctx, - parent_trace_context=ctx._orchestration_trace_context or ctx._parent_trace_context, + parent_trace_context=ctx._orchestration_trace_context or ctx._parent_trace_context, # pyright: ignore[reportPrivateUsage] start_time_ns=s_ns, end_time_ns=e_ns, is_error=True, error_message=str(failedEvent.failureDetails.errorMessage), version=t_ver, ) - if isinstance(sub_orch_task, task.RetryableTask): - if sub_orch_task._retry_policy is not None: - next_delay = sub_orch_task.compute_next_delay() - if next_delay is None: - sub_orch_task.fail( - f"Sub-orchestration task #{task_id} failed: {failedEvent.failureDetails.errorMessage}", - failedEvent.failureDetails, - ) - ctx.resume() - else: - sub_orch_task.increment_attempt_count() - ctx.create_timer_internal(next_delay, sub_orch_task) - elif isinstance(sub_orch_task, task.CompletableTask): - sub_orch_task.fail( + sub_orch_task_obj = cast(object, sub_orch_task) + if isinstance(sub_orch_task_obj, task.RetryableTask): + retryable_sub_orch_task = cast(task.RetryableTask[Any], sub_orch_task_obj) + next_delay = retryable_sub_orch_task.compute_next_delay() + if next_delay is None: + retryable_sub_orch_task.fail( + f"Sub-orchestration task #{task_id} failed: {failedEvent.failureDetails.errorMessage}", + failedEvent.failureDetails, + ) + ctx.resume() + else: + retryable_sub_orch_task.increment_attempt_count() + ctx.create_timer_internal(next_delay, retryable_sub_orch_task) + elif isinstance(sub_orch_task_obj, task.CompletableTask): + completable_sub_orch_task = cast(task.CompletableTask[Any], sub_orch_task_obj) + completable_sub_orch_task.fail( f"Sub-orchestration task #{task_id} failed: {failedEvent.failureDetails.errorMessage}", failedEvent.failureDetails, ) @@ -2395,18 +2404,18 @@ def _cancel_timer() -> None: else: raise TypeError("Unexpected sub-orchestration task type") elif event.HasField("eventRaised"): - if event.eventRaised.name in ctx._entity_task_id_map: - entity_id, operation, task_id = ctx._entity_task_id_map.get(event.eventRaised.name, (None, None, None)) + if event.eventRaised.name in ctx._entity_task_id_map: # pyright: ignore[reportPrivateUsage] + entity_id, operation, task_id = ctx._entity_task_id_map.get(event.eventRaised.name, (None, None, None)) # pyright: ignore[reportPrivateUsage] self._handle_entity_event_raised(ctx, event, entity_id, task_id, False) - elif event.eventRaised.name in ctx._entity_lock_task_id_map: - entity_id, task_id = ctx._entity_lock_task_id_map.get(event.eventRaised.name, (None, None)) + elif event.eventRaised.name in ctx._entity_lock_task_id_map: # pyright: ignore[reportPrivateUsage] + entity_id, task_id = ctx._entity_lock_task_id_map.get(event.eventRaised.name, (None, None)) # pyright: ignore[reportPrivateUsage] self._handle_entity_event_raised(ctx, event, entity_id, task_id, True) else: # event names are case-insensitive event_name = event.eventRaised.name.casefold() if not ctx.is_replaying: self._logger.info(f"{ctx.instance_id} Event raised: {event_name}") - task_list = ctx._pending_events.get(event_name, None) + task_list = ctx._pending_events.get(event_name, None) # pyright: ignore[reportPrivateUsage] decoded_result: Any | None = None if task_list: event_task = task_list.pop(0) @@ -2414,14 +2423,14 @@ def _cancel_timer() -> None: decoded_result = shared.from_json(event.eventRaised.input.value) event_task.complete(decoded_result) if not task_list: - del ctx._pending_events[event_name] + del ctx._pending_events[event_name] # pyright: ignore[reportPrivateUsage] ctx.resume() else: # buffer the event - event_list = ctx._received_events.get(event_name, None) + event_list = ctx._received_events.get(event_name, None) # pyright: ignore[reportPrivateUsage] if not event_list: event_list = [] - ctx._received_events[event_name] = event_list + ctx._received_events[event_name] = event_list # pyright: ignore[reportPrivateUsage] if not ph.is_empty(event.eventRaised.input): decoded_result = shared.from_json(event.eventRaised.input.value) event_list.append(decoded_result) @@ -2457,8 +2466,8 @@ def _cancel_timer() -> None: # This history event confirms that the entity operation was successfully scheduled. # Remove the entityOperationCalled event from the pending action list so we don't schedule it again entity_call_id = event.eventId - action = ctx._pending_actions.pop(entity_call_id, None) - entity_task = ctx._pending_tasks.get(entity_call_id, None) + action = ctx._pending_actions.pop(entity_call_id, None) # pyright: ignore[reportPrivateUsage] + entity_task = ctx._pending_tasks.get(entity_call_id, None) # pyright: ignore[reportPrivateUsage] if not action: raise _get_non_determinism_error( entity_call_id, task.get_name(ctx.call_entity) @@ -2473,12 +2482,12 @@ def _cancel_timer() -> None: operation = event.entityOperationCalled.operation except ValueError: raise RuntimeError(f"Could not parse entity ID from targetInstanceId '{event.entityOperationCalled.targetInstanceId.value}'") - ctx._entity_task_id_map[event.entityOperationCalled.requestId] = (entity_id, operation, entity_call_id) + ctx._entity_task_id_map[event.entityOperationCalled.requestId] = (entity_id, operation, entity_call_id) # pyright: ignore[reportPrivateUsage] elif event.HasField("entityOperationSignaled"): # This history event confirms that the entity signal was successfully scheduled. # Remove the entityOperationSignaled event from the pending action list so we don't schedule it entity_signal_id = event.eventId - action = ctx._pending_actions.pop(entity_signal_id, None) + action = ctx._pending_actions.pop(entity_signal_id, None) # pyright: ignore[reportPrivateUsage] if not action: raise _get_non_determinism_error( entity_signal_id, task.get_name(ctx.signal_entity) @@ -2491,8 +2500,8 @@ def _cancel_timer() -> None: elif event.HasField("entityLockRequested"): section_id = event.entityLockRequested.criticalSectionId task_id = event.eventId - action = ctx._pending_actions.pop(task_id, None) - entity_task = ctx._pending_tasks.get(task_id, None) + action = ctx._pending_actions.pop(task_id, None) # pyright: ignore[reportPrivateUsage] + entity_task = ctx._pending_tasks.get(task_id, None) # pyright: ignore[reportPrivateUsage] if not action: raise _get_non_determinism_error( task_id, task.get_name(ctx.lock_entities) @@ -2502,19 +2511,19 @@ def _cancel_timer() -> None: raise _get_wrong_action_type_error( task_id, expected_method_name, action ) - ctx._entity_lock_id_map[section_id] = task_id + ctx._entity_lock_id_map[section_id] = task_id # pyright: ignore[reportPrivateUsage] elif event.HasField("entityUnlockSent"): # Remove the unlock tasks as they have already been processed - tasks_to_remove = [] - for task_id, action in ctx._pending_actions.items(): + tasks_to_remove: list[int] = [] + for task_id, action in ctx._pending_actions.items(): # pyright: ignore[reportPrivateUsage] if action.HasField("sendEntityMessage") and action.sendEntityMessage.HasField("entityUnlockSent"): if action.sendEntityMessage.entityUnlockSent.criticalSectionId == event.entityUnlockSent.criticalSectionId: tasks_to_remove.append(task_id) for task_to_remove in tasks_to_remove: - ctx._pending_actions.pop(task_to_remove, None) + ctx._pending_actions.pop(task_to_remove, None) # pyright: ignore[reportPrivateUsage] elif event.HasField("entityLockGranted"): section_id = event.entityLockGranted.criticalSectionId - task_id = ctx._entity_lock_id_map.pop(section_id, None) + task_id = ctx._entity_lock_id_map.pop(section_id, None) # pyright: ignore[reportPrivateUsage] if not task_id: # Unexpected lock grant for unknown section; log and skip. if not ctx.is_replaying: @@ -2522,24 +2531,24 @@ def _cancel_timer() -> None: f"{ctx.instance_id}: Ignoring unexpected entityLockGranted event for criticalSectionId '{section_id}'." ) return - entity_task = ctx._pending_tasks.pop(task_id, None) + entity_task = ctx._pending_tasks.pop(task_id, None) # pyright: ignore[reportPrivateUsage] if not entity_task: if not ctx.is_replaying: self._logger.warning( f"{ctx.instance_id}: Ignoring unexpected entityLockGranted event for criticalSectionId '{section_id}'." ) return - ctx._entity_context.complete_acquire(section_id) + ctx._entity_context.complete_acquire(section_id) # pyright: ignore[reportPrivateUsage] entity_task.complete(EntityLock(ctx)) ctx.resume() elif event.HasField("entityOperationCompleted"): request_id = event.entityOperationCompleted.requestId - entity_id, operation, task_id = ctx._entity_task_id_map.pop(request_id, (None, None, None)) + entity_id, operation, task_id = ctx._entity_task_id_map.pop(request_id, (None, None, None)) # pyright: ignore[reportPrivateUsage] if not entity_id: raise RuntimeError(f"Could not parse entity ID from request ID '{request_id}'") if not task_id: raise RuntimeError(f"Could not find matching task ID for entity operation with request ID '{request_id}'") - entity_task = ctx._pending_tasks.pop(task_id, None) + entity_task = ctx._pending_tasks.pop(task_id, None) # pyright: ignore[reportPrivateUsage] if not entity_task: if not ctx.is_replaying: self._logger.warning( @@ -2549,19 +2558,19 @@ def _cancel_timer() -> None: result = None if not ph.is_empty(event.entityOperationCompleted.output): result = shared.from_json(event.entityOperationCompleted.output.value) - ctx._entity_context.recover_lock_after_call(entity_id) + ctx._entity_context.recover_lock_after_call(entity_id) # pyright: ignore[reportPrivateUsage] entity_task.complete(result) ctx.resume() elif event.HasField("entityOperationFailed"): request_id = event.entityOperationFailed.requestId - entity_id, operation, task_id = ctx._entity_task_id_map.pop(request_id, (None, None, None)) + entity_id, operation, task_id = ctx._entity_task_id_map.pop(request_id, (None, None, None)) # pyright: ignore[reportPrivateUsage] if not entity_id: raise RuntimeError(f"Could not parse entity ID from request ID '{request_id}'") if operation is None: raise RuntimeError(f"Could not parse operation name from request ID '{request_id}'") if not task_id: raise RuntimeError(f"Could not find matching task ID for entity operation with request ID '{request_id}'") - entity_task = ctx._pending_tasks.pop(task_id, None) + entity_task = ctx._pending_tasks.pop(task_id, None) # pyright: ignore[reportPrivateUsage] if not entity_task: if not ctx.is_replaying: self._logger.warning( @@ -2573,7 +2582,7 @@ def _cancel_timer() -> None: operation, event.entityOperationFailed.failureDetails ) - ctx._entity_context.recover_lock_after_call(entity_id) + ctx._entity_context.recover_lock_after_call(entity_id) # pyright: ignore[reportPrivateUsage] entity_task.fail(str(failure), failure) ctx.resume() elif event.HasField("orchestratorCompleted"): @@ -2583,14 +2592,14 @@ def _cancel_timer() -> None: # Check if this eventSent corresponds to an entity operation call after being translated to the old # entity protocol by the Durable WebJobs extension. If so, treat this message similarly to # entityOperationCalled and remove the pending action. Also store the entity id and event id for later - action = ctx._pending_actions.pop(event.eventId, None) + action = ctx._pending_actions.pop(event.eventId, None) # pyright: ignore[reportPrivateUsage] if action and action.HasField("sendEntityMessage"): if action.sendEntityMessage.HasField("entityOperationCalled"): entity_id, event_id = self._parse_entity_event_sent_input(event) - ctx._entity_task_id_map[event_id] = (entity_id, action.sendEntityMessage.entityOperationCalled.operation, event.eventId) + ctx._entity_task_id_map[event_id] = (entity_id, action.sendEntityMessage.entityOperationCalled.operation, event.eventId) # pyright: ignore[reportPrivateUsage] elif action.sendEntityMessage.HasField("entityLockRequested"): entity_id, event_id = self._parse_entity_event_sent_input(event) - ctx._entity_lock_task_id_map[event_id] = (entity_id, event.eventId) + ctx._entity_lock_task_id_map[event_id] = (entity_id, event.eventId) # pyright: ignore[reportPrivateUsage] else: eventType = event.WhichOneof("eventType") raise task.OrchestrationStateError( @@ -2623,7 +2632,7 @@ def _handle_entity_event_raised(self, raise RuntimeError(f"Could not retrieve entity ID for entity-related eventRaised with ID '{event.eventId}'") if task_id is None: raise RuntimeError(f"Could not retrieve task ID for entity-related eventRaised with ID '{event.eventId}'") - entity_task = ctx._pending_tasks.pop(task_id, None) + entity_task = ctx._pending_tasks.pop(task_id, None) # pyright: ignore[reportPrivateUsage] if not entity_task: raise RuntimeError(f"Could not retrieve entity task for entity-related eventRaised with ID '{event.eventId}'") result = None @@ -2631,10 +2640,10 @@ def _handle_entity_event_raised(self, # TODO: Investigate why the event result is wrapped in a dict with "result" key result = shared.from_json(event.eventRaised.input.value)["result"] if is_lock_event: - ctx._entity_context.complete_acquire(event.eventRaised.name) + ctx._entity_context.complete_acquire(event.eventRaised.name) # pyright: ignore[reportPrivateUsage] entity_task.complete(EntityLock(ctx)) else: - ctx._entity_context.recover_lock_after_call(entity_id) + ctx._entity_context.recover_lock_after_call(entity_id) # pyright: ignore[reportPrivateUsage] entity_task.complete(result) ctx.resume() @@ -2754,7 +2763,7 @@ def execute( if not callable(method): raise TypeError(f"Entity operation '{operation}' is not callable") # Execute the entity method - entity_instance._initialize_entity_context(ctx) + entity_instance._initialize_entity_context(ctx) # pyright: ignore[reportPrivateUsage] cache_key = (type(entity_instance), operation) has_required_param = self._entity_method_cache.get(cache_key) if has_required_param is None: @@ -2876,18 +2885,18 @@ def __init__(self, concurrency_options: ConcurrencyOptions, logger: logging.Logg self.concurrency_options = concurrency_options self._logger = logger - self.activity_semaphore = None - self.orchestration_semaphore = None - self.entity_semaphore = None + self.activity_semaphore: asyncio.Semaphore | None = None + self.orchestration_semaphore: asyncio.Semaphore | None = None + self.entity_semaphore: asyncio.Semaphore | None = None # Don't create queues here - defer until we have an event loop - self.activity_queue: asyncio.Queue | None = None - self.orchestration_queue: asyncio.Queue | None = None - self.entity_batch_queue: asyncio.Queue | None = None + self.activity_queue: asyncio.Queue[_WorkItem] | None = None + self.orchestration_queue: asyncio.Queue[_WorkItem] | None = None + self.entity_batch_queue: asyncio.Queue[_WorkItem] | None = None self._queue_event_loop: asyncio.AbstractEventLoop | None = None # Store work items when no event loop is available - self._pending_activity_work: list = [] - self._pending_orchestration_work: list = [] - self._pending_entity_batch_work: list = [] + self._pending_activity_work: list[_WorkItem] = [] + self._pending_orchestration_work: list[_WorkItem] = [] + self._pending_entity_batch_work: list[_WorkItem] = [] self.thread_pool = self._create_thread_pool() self._pool_is_shutdown = False self._shutdown = False @@ -2910,7 +2919,7 @@ def prepare_for_run(self) -> None: self._shutdown = False self._ensure_thread_pool() - def _ensure_queues_for_current_loop(self): + def _ensure_queues_for_current_loop(self) -> None: """Ensure queues are bound to the current event loop.""" try: current_loop = asyncio.get_running_loop() @@ -2926,9 +2935,9 @@ def _ensure_queues_for_current_loop(self): # Need to recreate queues for the current event loop # First, preserve any existing work items - existing_activity_items = [] - existing_orchestration_items = [] - existing_entity_batch_items = [] + existing_activity_items: list[_WorkItem] = [] + existing_orchestration_items: list[_WorkItem] = [] + existing_entity_batch_items: list[_WorkItem] = [] if self.activity_queue is not None: try: @@ -2982,7 +2991,7 @@ def _ensure_queues_for_current_loop(self): self._pending_orchestration_work.clear() self._pending_entity_batch_work.clear() - async def run(self): + async def run(self) -> None: self._ensure_thread_pool() # Ensure queues are properly bound to the current event loop @@ -3016,7 +3025,7 @@ async def run(self): self._logger.error(f"Shutting down worker - Uncaught error in worker manager: {queue_exception}") while self.activity_queue is not None and not self.activity_queue.empty(): try: - func, cancellation_func, args, kwargs = self.activity_queue.get_nowait() + _func, cancellation_func, args, kwargs = self.activity_queue.get_nowait() await self._run_func(cancellation_func, *args, **kwargs) self._logger.error(f"Activity work item args: {args}, kwargs: {kwargs}") except asyncio.QueueEmpty: @@ -3026,7 +3035,7 @@ async def run(self): self._logger.error(f"Uncaught error while cancelling activity work item: {cancellation_exception}") while self.orchestration_queue is not None and not self.orchestration_queue.empty(): try: - func, cancellation_func, args, kwargs = self.orchestration_queue.get_nowait() + _func, cancellation_func, args, kwargs = self.orchestration_queue.get_nowait() await self._run_func(cancellation_func, *args, **kwargs) self._logger.error(f"Orchestration work item args: {args}, kwargs: {kwargs}") except asyncio.QueueEmpty: @@ -3036,7 +3045,7 @@ async def run(self): self._logger.error(f"Uncaught error while cancelling orchestration work item: {cancellation_exception}") while self.entity_batch_queue is not None and not self.entity_batch_queue.empty(): try: - func, cancellation_func, args, kwargs = self.entity_batch_queue.get_nowait() + _func, cancellation_func, args, kwargs = self.entity_batch_queue.get_nowait() await self._run_func(cancellation_func, *args, **kwargs) self._logger.error(f"Entity batch work item args: {args}, kwargs: {kwargs}") except asyncio.QueueEmpty: @@ -3050,9 +3059,9 @@ async def run(self): self.thread_pool.shutdown(wait=True) self._pool_is_shutdown = True - async def _consume_queue(self, queue: asyncio.Queue, semaphore: asyncio.Semaphore): + async def _consume_queue(self, queue: asyncio.Queue[_WorkItem], semaphore: asyncio.Semaphore) -> None: # List to track running tasks - running_tasks: set[asyncio.Task] = set() + running_tasks: set[asyncio.Task[Any]] = set() while True: # Clean up completed tasks @@ -3076,8 +3085,10 @@ async def _consume_queue(self, queue: asyncio.Queue, semaphore: asyncio.Semaphor running_tasks.add(task) async def _process_work_item( - self, semaphore: asyncio.Semaphore, queue: asyncio.Queue, func, cancellation_func, args, kwargs - ): + self, semaphore: asyncio.Semaphore, queue: asyncio.Queue[_WorkItem], + func: Callable[..., Any], cancellation_func: Callable[..., Any], + args: tuple[Any, ...], kwargs: dict[str, Any] + ) -> None: async with semaphore: try: await self._run_func(func, *args, **kwargs) @@ -3087,7 +3098,7 @@ async def _process_work_item( finally: queue.task_done() - async def _run_func(self, func, *args, **kwargs): + async def _run_func(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: if inspect.iscoroutinefunction(func): return await func(*args, **kwargs) else: @@ -3097,7 +3108,10 @@ async def _run_func(self, func, *args, **kwargs): self.thread_pool, lambda: func(*args, **kwargs) ) - def submit_activity(self, func, cancellation_func, *args, **kwargs): + def submit_activity( + self, func: Callable[..., Any], cancellation_func: Callable[..., Any], + *args: Any, **kwargs: Any + ) -> None: if self._shutdown: raise RuntimeError("Cannot submit new work items after shutdown has been initiated.") work_item = (func, cancellation_func, args, kwargs) @@ -3108,7 +3122,10 @@ def submit_activity(self, func, cancellation_func, *args, **kwargs): # No event loop running, store in pending list self._pending_activity_work.append(work_item) - def submit_orchestration(self, func, cancellation_func, *args, **kwargs): + def submit_orchestration( + self, func: Callable[..., Any], cancellation_func: Callable[..., Any], + *args: Any, **kwargs: Any + ) -> None: if self._shutdown: raise RuntimeError("Cannot submit new work items after shutdown has been initiated.") work_item = (func, cancellation_func, args, kwargs) @@ -3119,7 +3136,10 @@ def submit_orchestration(self, func, cancellation_func, *args, **kwargs): # No event loop running, store in pending list self._pending_orchestration_work.append(work_item) - def submit_entity_batch(self, func, cancellation_func, *args, **kwargs): + def submit_entity_batch( + self, func: Callable[..., Any], cancellation_func: Callable[..., Any], + *args: Any, **kwargs: Any + ) -> None: if self._shutdown: raise RuntimeError("Cannot submit new work items after shutdown has been initiated.") work_item = (func, cancellation_func, args, kwargs) @@ -3130,10 +3150,10 @@ def submit_entity_batch(self, func, cancellation_func, *args, **kwargs): # No event loop running, store in pending list self._pending_entity_batch_work.append(work_item) - def shutdown(self): + def shutdown(self) -> None: self._shutdown = True - async def reset_for_new_run(self): + async def reset_for_new_run(self) -> None: """Reset the manager state for a new run.""" self.prepare_for_run() # Clear any existing queues - they'll be recreated when needed @@ -3142,21 +3162,21 @@ async def reset_for_new_run(self): # This ensures no items from previous runs remain try: while not self.activity_queue.empty(): - func, cancellation_func, args, kwargs = self.activity_queue.get_nowait() + _func, cancellation_func, args, kwargs = self.activity_queue.get_nowait() await self._run_func(cancellation_func, *args, **kwargs) except Exception as reset_exception: self._logger.warning(f"Error while clearing activity queue during reset: {reset_exception}") if self.orchestration_queue is not None: try: while not self.orchestration_queue.empty(): - func, cancellation_func, args, kwargs = self.orchestration_queue.get_nowait() + _func, cancellation_func, args, kwargs = self.orchestration_queue.get_nowait() await self._run_func(cancellation_func, *args, **kwargs) except Exception as reset_exception: self._logger.warning(f"Error while clearing orchestration queue during reset: {reset_exception}") if self.entity_batch_queue is not None: try: while not self.entity_batch_queue.empty(): - func, cancellation_func, args, kwargs = self.entity_batch_queue.get_nowait() + _func, cancellation_func, args, kwargs = self.entity_batch_queue.get_nowait() await self._run_func(cancellation_func, *args, **kwargs) except Exception as reset_exception: self._logger.warning(f"Error while clearing entity queue during reset: {reset_exception}") diff --git a/pyrightconfig.json b/pyrightconfig.json new file mode 100644 index 0000000..6df52ee --- /dev/null +++ b/pyrightconfig.json @@ -0,0 +1,17 @@ +{ + "include": [ + "durabletask", + "durabletask-azuremanaged" + ], + "exclude": [ + "**/__pycache__", + "**/*_pb2.py", + "**/*_pb2.pyi", + "**/*_pb2_grpc.py", + "**/*_pb2_grpc.pyi", + "**/.venv*", + ".venv*" + ], + "pythonVersion": "3.10", + "typeCheckingMode": "strict" +}