From 6ac8d147a01ca59a816dd02aa3a51dbe3261265a Mon Sep 17 00:00:00 2001 From: Filinto Duran <1373693+filintod@users.noreply.github.com> Date: Mon, 27 Oct 2025 23:09:02 -0500 Subject: [PATCH 01/10] add options to client argument and e2e test Signed-off-by: Filinto Duran <1373693+filintod@users.noreply.github.com> --- durabletask/aio/client.py | 23 +++-- durabletask/aio/internal/shared.py | 51 +++++++--- durabletask/client.py | 23 +++-- durabletask/internal/shared.py | 39 ++++++-- .../test_grpc_aio_channel_options.py | 94 +++++++++++++++++++ .../durabletask/test_grpc_channel_options.py | 81 ++++++++++++++++ tests/durabletask/test_orchestration_e2e.py | 7 +- 7 files changed, 275 insertions(+), 43 deletions(-) create mode 100644 tests/durabletask/test_grpc_aio_channel_options.py create mode 100644 tests/durabletask/test_grpc_channel_options.py diff --git a/durabletask/aio/client.py b/durabletask/aio/client.py index 4ec9bbf..a295c7f 100644 --- a/durabletask/aio/client.py +++ b/durabletask/aio/client.py @@ -20,15 +20,17 @@ class AsyncTaskHubGrpcClient: - - def __init__(self, *, - host_address: Optional[str] = None, - metadata: Optional[list[tuple[str, str]]] = None, - log_handler: Optional[logging.Handler] = None, - log_formatter: Optional[logging.Formatter] = None, - secure_channel: bool = False, - interceptors: Optional[Sequence[ClientInterceptor]] = None): - + def __init__( + self, + *, + host_address: Optional[str] = None, + metadata: Optional[list[tuple[str, str]]] = None, + log_handler: Optional[logging.Handler] = None, + log_formatter: Optional[logging.Formatter] = None, + secure_channel: bool = False, + interceptors: Optional[Sequence[ClientInterceptor]] = None, + channel_options: Optional[Sequence[tuple[str, Any]]] = None, + ): if interceptors is not None: interceptors = list(interceptors) if metadata is not None: @@ -41,7 +43,8 @@ def __init__(self, *, channel = get_grpc_aio_channel( host_address=host_address, secure_channel=secure_channel, - interceptors=interceptors + interceptors=interceptors, + options=channel_options, ) self._channel = channel self._stub = stubs.TaskHubSidecarServiceStub(channel) diff --git a/durabletask/aio/internal/shared.py b/durabletask/aio/internal/shared.py index 6bdb256..87266a1 100644 --- a/durabletask/aio/internal/shared.py +++ b/durabletask/aio/internal/shared.py @@ -1,49 +1,72 @@ # Copyright (c) The Dapr Authors. # Licensed under the MIT License. -from typing import Optional, Sequence, Union +from typing import Any, Optional, Sequence, Union import grpc from grpc import aio as grpc_aio from durabletask.internal.shared import ( - get_default_host_address, - SECURE_PROTOCOLS, INSECURE_PROTOCOLS, + SECURE_PROTOCOLS, + get_default_host_address, ) - ClientInterceptor = Union[ grpc_aio.UnaryUnaryClientInterceptor, grpc_aio.UnaryStreamClientInterceptor, grpc_aio.StreamUnaryClientInterceptor, - grpc_aio.StreamStreamClientInterceptor + grpc_aio.StreamStreamClientInterceptor, ] def get_grpc_aio_channel( - host_address: Optional[str], - secure_channel: bool = False, - interceptors: Optional[Sequence[ClientInterceptor]] = None) -> grpc_aio.Channel: + host_address: Optional[str], + secure_channel: bool = False, + interceptors: Optional[Sequence[ClientInterceptor]] = None, + options: Optional[Sequence[tuple[str, Any]]] = None, +) -> grpc_aio.Channel: + """create a grpc asyncio channel + Args: + host_address: The host address of the gRPC server. If None, uses the default address. + secure_channel: Whether to use a secure channel (TLS/SSL). Defaults to False. + interceptors: Optional sequence of client interceptors to apply to the channel. + options: Optional sequence of gRPC channel options as (key, value) tuples. Keys defined in https://grpc.github.io/grpc/core/group__grpc__arg__keys.html + """ if host_address is None: host_address = get_default_host_address() for protocol in SECURE_PROTOCOLS: if host_address.lower().startswith(protocol): secure_channel = True - host_address = host_address[len(protocol):] + host_address = host_address[len(protocol) :] break for protocol in INSECURE_PROTOCOLS: if host_address.lower().startswith(protocol): secure_channel = False - host_address = host_address[len(protocol):] + host_address = host_address[len(protocol) :] break + # Create the base channel if secure_channel: - channel = grpc_aio.secure_channel(host_address, grpc.ssl_channel_credentials(), interceptors=interceptors) - else: - channel = grpc_aio.insecure_channel(host_address, interceptors=interceptors) + if options is not None: + return grpc_aio.secure_channel( + host_address, + grpc.ssl_channel_credentials(), + interceptors=interceptors, + options=options, + ) + return grpc_aio.secure_channel( + host_address, grpc.ssl_channel_credentials(), interceptors=interceptors + ) - return channel + if options is not None: + # validate all options keys prefix starts with `grpc.` + if not all(key.startswith('grpc.') for key, _ in options): + raise ValueError( + f'All options keys must start with `grpc.`. Invalid options: {options}' + ) + return grpc_aio.insecure_channel(host_address, interceptors=interceptors, options=options) + return grpc_aio.insecure_channel(host_address, interceptors=interceptors) diff --git a/durabletask/client.py b/durabletask/client.py index b155bd6..0c884f4 100644 --- a/durabletask/client.py +++ b/durabletask/client.py @@ -91,15 +91,17 @@ def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Op class TaskHubGrpcClient: - - def __init__(self, *, - host_address: Optional[str] = None, - metadata: Optional[list[tuple[str, str]]] = None, - log_handler: Optional[logging.Handler] = None, - log_formatter: Optional[logging.Formatter] = None, - secure_channel: bool = False, - interceptors: Optional[Sequence[shared.ClientInterceptor]] = None): - + def __init__( + self, + *, + host_address: Optional[str] = None, + metadata: Optional[list[tuple[str, str]]] = None, + log_handler: Optional[logging.Handler] = None, + log_formatter: Optional[logging.Formatter] = None, + secure_channel: bool = False, + interceptors: Optional[Sequence[shared.ClientInterceptor]] = None, + channel_options: Optional[Sequence[tuple[str, Any]]] = None, + ): # If the caller provided metadata, we need to create a new interceptor for it and # add it to the list of interceptors. if interceptors is not None: @@ -114,7 +116,8 @@ def __init__(self, *, channel = shared.get_grpc_channel( host_address=host_address, secure_channel=secure_channel, - interceptors=interceptors + interceptors=interceptors, + options=channel_options, ) self._stub = stubs.TaskHubSidecarServiceStub(channel) self._logger = shared.get_logger("client", log_handler, log_formatter) diff --git a/durabletask/internal/shared.py b/durabletask/internal/shared.py index 22ac3df..d29ecc4 100644 --- a/durabletask/internal/shared.py +++ b/durabletask/internal/shared.py @@ -51,9 +51,19 @@ def get_default_host_address() -> str: def get_grpc_channel( - host_address: Optional[str], - secure_channel: bool = False, - interceptors: Optional[Sequence[ClientInterceptor]] = None) -> grpc.Channel: + host_address: Optional[str], + secure_channel: bool = False, + interceptors: Optional[Sequence[ClientInterceptor]] = None, + options: Optional[Sequence[tuple[str, Any]]] = None, +) -> grpc.Channel: + """create a grpc channel + + Args: + host_address: The host address of the gRPC server. If None, uses the default address. + secure_channel: Whether to use a secure channel (TLS/SSL). Defaults to False. + interceptors: Optional sequence of client interceptors to apply to the channel. + options: Optional sequence of gRPC channel options as (key, value) tuples. Keys defined in https://grpc.github.io/grpc/core/group__grpc__arg__keys.html + """ if host_address is None: host_address = get_default_host_address() @@ -61,21 +71,34 @@ def get_grpc_channel( if host_address.lower().startswith(protocol): secure_channel = True # remove the protocol from the host name - host_address = host_address[len(protocol):] + host_address = host_address[len(protocol) :] break for protocol in INSECURE_PROTOCOLS: if host_address.lower().startswith(protocol): secure_channel = False # remove the protocol from the host name - host_address = host_address[len(protocol):] + host_address = host_address[len(protocol) :] break # Create the base channel - if secure_channel: - channel = grpc.secure_channel(host_address, grpc.ssl_channel_credentials()) + if options is not None: + # validate all options keys prefix starts with `grpc.` + if not all(key.startswith('grpc.') for key, _ in options): + raise ValueError( + f'All options keys must start with `grpc.`. Invalid options: {options}' + ) + if secure_channel: + channel = grpc.secure_channel( + host_address, grpc.ssl_channel_credentials(), options=options + ) + else: + channel = grpc.insecure_channel(host_address, options=options) else: - channel = grpc.insecure_channel(host_address) + if secure_channel: + channel = grpc.secure_channel(host_address, grpc.ssl_channel_credentials()) + else: + channel = grpc.insecure_channel(host_address) # Apply interceptors ONLY if they exist if interceptors: diff --git a/tests/durabletask/test_grpc_aio_channel_options.py b/tests/durabletask/test_grpc_aio_channel_options.py new file mode 100644 index 0000000..54830c8 --- /dev/null +++ b/tests/durabletask/test_grpc_aio_channel_options.py @@ -0,0 +1,94 @@ +import json +from unittest.mock import patch + +import pytest + +from durabletask.aio.internal.shared import get_grpc_aio_channel + +HOST_ADDRESS = 'localhost:50051' + + +def _find_option(options, key): + for k, v in options: + if k == key: + return v + raise AssertionError(f'Option with key {key} not found in options: {options}') + + +def test_aio_channel_passes_base_options_and_max_lengths(): + base_options = [ + ('grpc.max_send_message_length', 4321), + ('grpc.max_receive_message_length', 8765), + ('grpc.primary_user_agent', 'durabletask-aio-tests'), + ] + with patch('durabletask.aio.internal.shared.grpc_aio.insecure_channel') as mock_channel: + get_grpc_aio_channel(HOST_ADDRESS, False, options=base_options) + # Ensure called with options kwarg + assert mock_channel.call_count == 1 + args, kwargs = mock_channel.call_args + assert args[0] == HOST_ADDRESS + assert 'options' in kwargs + opts = kwargs['options'] + # Check our base options made it through + assert ('grpc.max_send_message_length', 4321) in opts + assert ('grpc.max_receive_message_length', 8765) in opts + assert ('grpc.primary_user_agent', 'durabletask-aio-tests') in opts + + +def test_aio_channel_merges_env_keepalive_and_retry(monkeypatch: pytest.MonkeyPatch): + # retry grpc option + # service_config ref => https://github.com/grpc/grpc-proto/blob/master/grpc/service_config/service_config.proto#L44 + max_attempts = 4 + initial_backoff_ms = 250 + max_backoff_ms = 2000 + backoff_multiplier = 1.5 + codes = ['RESOURCE_EXHAUSTED'] + service_config = { + 'methodConfig': [ + { + 'name': [{'service': ''}], # match all services/methods + 'retryPolicy': { + 'maxAttempts': max_attempts, + 'initialBackoff': f'{initial_backoff_ms / 1000.0}s', + 'maxBackoff': f'{max_backoff_ms / 1000.0}s', + 'backoffMultiplier': backoff_multiplier, + 'retryableStatusCodes': codes, + }, + } + ] + } + + base_options = [('grpc.service_config', json.dumps(service_config))] + + with patch('durabletask.aio.internal.shared.grpc_aio.insecure_channel') as mock_channel: + get_grpc_aio_channel(HOST_ADDRESS, False, options=base_options) + + args, kwargs = mock_channel.call_args + assert args[0] == HOST_ADDRESS + assert 'options' in kwargs + opts = kwargs['options'] + + # Retry service config present and parses correctly + svc_cfg_str = _find_option(opts, 'grpc.service_config') + svc_cfg = json.loads(svc_cfg_str) + assert 'methodConfig' in svc_cfg and isinstance(svc_cfg['methodConfig'], list) + retry_policy = svc_cfg['methodConfig'][0]['retryPolicy'] + assert retry_policy['maxAttempts'] == 4 + assert retry_policy['initialBackoff'] == f'{250 / 1000.0}s' + assert retry_policy['maxBackoff'] == f'{2000 / 1000.0}s' + assert retry_policy['backoffMultiplier'] == 1.5 + # Codes are upper-cased list + assert 'RESOURCE_EXHAUSTED' in retry_policy['retryableStatusCodes'] + + +def test_aio_secure_channel_receives_options_when_secure_true(): + base_options = [('grpc.max_receive_message_length', 999999)] + with ( + patch('durabletask.aio.internal.shared.grpc_aio.secure_channel') as mock_channel, + patch('grpc.ssl_channel_credentials') as mock_credentials, + ): + get_grpc_aio_channel(HOST_ADDRESS, True, options=base_options) + args, kwargs = mock_channel.call_args + assert args[0] == HOST_ADDRESS + assert args[1] == mock_credentials.return_value + assert ('grpc.max_receive_message_length', 999999) in kwargs.get('options', []) diff --git a/tests/durabletask/test_grpc_channel_options.py b/tests/durabletask/test_grpc_channel_options.py new file mode 100644 index 0000000..b8ac533 --- /dev/null +++ b/tests/durabletask/test_grpc_channel_options.py @@ -0,0 +1,81 @@ +import json +from unittest.mock import ANY, patch + +import pytest + +from durabletask.internal.shared import get_grpc_channel + +HOST_ADDRESS = 'localhost:50051' + + +def _find_option(options, key): + for k, v in options: + if k == key: + return v + raise AssertionError(f'Option with key {key} not found in options: {options}') + + +def test_sync_channel_passes_base_options_and_max_lengths(): + base_options = [ + ('grpc.max_send_message_length', 1234), + ('grpc.max_receive_message_length', 5678), + ('grpc.primary_user_agent', 'durabletask-tests'), + ] + with patch('grpc.insecure_channel') as mock_channel: + get_grpc_channel(HOST_ADDRESS, False, options=base_options) + # Ensure called with options kwarg + assert mock_channel.call_count == 1 + args, kwargs = mock_channel.call_args + assert args[0] == HOST_ADDRESS + assert 'options' in kwargs + opts = kwargs['options'] + # Check our base options made it through + assert ('grpc.max_send_message_length', 1234) in opts + assert ('grpc.max_receive_message_length', 5678) in opts + assert ('grpc.primary_user_agent', 'durabletask-tests') in opts + + +def test_sync_channel_merges_env_keepalive_and_retry(monkeypatch: pytest.MonkeyPatch): + # retry grpc option + # service_config ref => https://github.com/grpc/grpc-proto/blob/master/grpc/service_config/service_config.proto#L44 + max_attempts = 4 + initial_backoff_ms = 250 + max_backoff_ms = 2000 + backoff_multiplier = 1.5 + codes = ['ABORTED'] + service_config = { + 'methodConfig': [ + { + 'name': [{'service': ''}], # match all services/methods + 'retryPolicy': { + 'maxAttempts': max_attempts, + 'initialBackoff': f'{initial_backoff_ms / 1000.0}s', + 'maxBackoff': f'{max_backoff_ms / 1000.0}s', + 'backoffMultiplier': backoff_multiplier, + 'retryableStatusCodes': codes, + }, + } + ] + } + + base_options = [('grpc.service_config', json.dumps(service_config))] + + with patch('grpc.insecure_channel') as mock_channel: + get_grpc_channel(HOST_ADDRESS, False, options=base_options) + + args, kwargs = mock_channel.call_args + assert args[0] == HOST_ADDRESS + assert 'options' in kwargs + opts = kwargs['options'] + + # Retry service config present and parses correctly + svc_cfg_str = _find_option(opts, 'grpc.service_config') + svc_cfg = json.loads(svc_cfg_str) + assert 'methodConfig' in svc_cfg and isinstance(svc_cfg['methodConfig'], list) + retry_policy = svc_cfg['methodConfig'][0]['retryPolicy'] + assert retry_policy['maxAttempts'] == 4 + assert retry_policy['initialBackoff'] == f'{250 / 1000.0}s' + assert retry_policy['maxBackoff'] == f'{2000 / 1000.0}s' + assert retry_policy['backoffMultiplier'] == 1.5 + # Codes are upper-cased list + assert 'ABORTED' in retry_policy['retryableStatusCodes'] diff --git a/tests/durabletask/test_orchestration_e2e.py b/tests/durabletask/test_orchestration_e2e.py index f5651ff..5825d37 100644 --- a/tests/durabletask/test_orchestration_e2e.py +++ b/tests/durabletask/test_orchestration_e2e.py @@ -28,7 +28,12 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): w.add_orchestrator(empty_orchestrator) w.start() - c = client.TaskHubGrpcClient() + # set a custom max send length option + c = client.TaskHubGrpcClient( + channel_options=[ + ('grpc.max_send_message_length', 1024 * 1024), # 1MB + ] + ) id = c.schedule_new_orchestration(empty_orchestrator) state = c.wait_for_orchestration_completion(id, timeout=30) From e33a8d5c97a9bc8f7734b18d8c66f583e38b99d0 Mon Sep 17 00:00:00 2001 From: Filinto Duran <1373693+filintod@users.noreply.github.com> Date: Mon, 27 Oct 2025 23:13:58 -0500 Subject: [PATCH 02/10] correct docstring info on env var names Signed-off-by: Filinto Duran <1373693+filintod@users.noreply.github.com> --- durabletask/internal/shared.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/durabletask/internal/shared.py b/durabletask/internal/shared.py index d29ecc4..f613b3d 100644 --- a/durabletask/internal/shared.py +++ b/durabletask/internal/shared.py @@ -30,8 +30,8 @@ def get_default_host_address() -> str: Honors environment variables if present; otherwise defaults to localhost:4001. Supported environment variables (checked in order): - - DURABLETASK_GRPC_ENDPOINT (e.g., "localhost:4001", "grpcs://host:443") - - DURABLETASK_GRPC_HOST and DURABLETASK_GRPC_PORT + - DAPR_GRPC_ENDPOINT (e.g., "localhost:4001", "grpcs://host:443") + - DAPR_GRPC_HOST/DAPR_RUNTIME_HOST and DAPR_GRPC_PORT """ import os @@ -59,7 +59,7 @@ def get_grpc_channel( """create a grpc channel Args: - host_address: The host address of the gRPC server. If None, uses the default address. + host_address: The host address of the gRPC server. If None, uses the default address (as defined in get_default_host_address above). secure_channel: Whether to use a secure channel (TLS/SSL). Defaults to False. interceptors: Optional sequence of client interceptors to apply to the channel. options: Optional sequence of gRPC channel options as (key, value) tuples. Keys defined in https://grpc.github.io/grpc/core/group__grpc__arg__keys.html From 4eb8a9fe9816f16c82145685ff8b8ac513176a48 Mon Sep 17 00:00:00 2001 From: Filinto Duran <1373693+filintod@users.noreply.github.com> Date: Fri, 31 Oct 2025 10:38:36 -0500 Subject: [PATCH 03/10] share validating grpc options Signed-off-by: Filinto Duran <1373693+filintod@users.noreply.github.com> --- durabletask/aio/internal/shared.py | 14 ++-- durabletask/internal/shared.py | 17 +++-- .../test_grpc_aio_channel_options.py | 72 +++++++++---------- .../durabletask/test_grpc_channel_options.py | 66 ++++++++--------- tests/durabletask/test_orchestration_e2e.py | 2 +- 5 files changed, 89 insertions(+), 82 deletions(-) diff --git a/durabletask/aio/internal/shared.py b/durabletask/aio/internal/shared.py index edbb515..113f73b 100644 --- a/durabletask/aio/internal/shared.py +++ b/durabletask/aio/internal/shared.py @@ -1,15 +1,17 @@ # Copyright (c) The Dapr Authors. # Licensed under the MIT License. -from typing import Any, Optional, Sequence, Union +from typing import Any, Dict, Optional, Sequence, Union import grpc from grpc import aio as grpc_aio +from grpc.aio import ChannelArgumentType from durabletask.internal.shared import ( INSECURE_PROTOCOLS, SECURE_PROTOCOLS, get_default_host_address, + validate_grpc_options, ) ClientInterceptor = Union[ @@ -24,7 +26,7 @@ def get_grpc_aio_channel( host_address: Optional[str], secure_channel: bool = False, interceptors: Optional[Sequence[ClientInterceptor]] = None, - options: Optional[Sequence[tuple[str, Any]]] = None, + options: Optional[ChannelArgumentType] = None, ) -> grpc_aio.Channel: """create a grpc asyncio channel @@ -50,13 +52,9 @@ def get_grpc_aio_channel( break # channel interceptors/options - channel_kwargs = dict(interceptors=interceptors) + channel_kwargs: Dict[str, ChannelArgumentType | Sequence[ClientInterceptor]] = dict(interceptors=interceptors) if options is not None: - # validate all options keys prefix starts with `grpc.` - if not all(key.startswith('grpc.') for key, _ in options): - raise ValueError( - f'All options keys must start with `grpc.`. Invalid options: {options}' - ) + validate_grpc_options(options) channel_kwargs["options"] = options if secure_channel: diff --git a/durabletask/internal/shared.py b/durabletask/internal/shared.py index 9ecd722..9c7f111 100644 --- a/durabletask/internal/shared.py +++ b/durabletask/internal/shared.py @@ -9,6 +9,7 @@ from typing import Any, Optional, Sequence, Union import grpc +from grpc.aio import ChannelArgumentType ClientInterceptor = Union[ grpc.UnaryUnaryClientInterceptor, @@ -50,6 +51,17 @@ def get_default_host_address() -> str: return "localhost:4001" +def validate_grpc_options(options: ChannelArgumentType): + """Validate that all gRPC options are valid. Mainly checking keys. Values can be string, int, float, bool and pointer""" + for key, value in options: + if not isinstance(key, str): + raise ValueError(f"gRPC option key must be a string. Invalid key: {key}") + if not all(key.startswith("grpc.") for key, _ in options): + raise ValueError( + f"All options keys must start with `grpc.`. Invalid options: {options}" + ) + + def get_grpc_channel( host_address: Optional[str], secure_channel: bool = False, @@ -84,10 +96,7 @@ def get_grpc_channel( # Create the base channel if options is not None: # validate all options keys prefix starts with `grpc.` - if not all(key.startswith('grpc.') for key, _ in options): - raise ValueError( - f'All options keys must start with `grpc.`. Invalid options: {options}' - ) + validate_grpc_options(options) if secure_channel: channel = grpc.secure_channel( host_address, grpc.ssl_channel_credentials(), options=options diff --git a/tests/durabletask/test_grpc_aio_channel_options.py b/tests/durabletask/test_grpc_aio_channel_options.py index 54830c8..2f64577 100644 --- a/tests/durabletask/test_grpc_aio_channel_options.py +++ b/tests/durabletask/test_grpc_aio_channel_options.py @@ -5,34 +5,34 @@ from durabletask.aio.internal.shared import get_grpc_aio_channel -HOST_ADDRESS = 'localhost:50051' +HOST_ADDRESS = "localhost:50051" def _find_option(options, key): for k, v in options: if k == key: return v - raise AssertionError(f'Option with key {key} not found in options: {options}') + raise AssertionError(f"Option with key {key} not found in options: {options}") def test_aio_channel_passes_base_options_and_max_lengths(): base_options = [ - ('grpc.max_send_message_length', 4321), - ('grpc.max_receive_message_length', 8765), - ('grpc.primary_user_agent', 'durabletask-aio-tests'), + ("grpc.max_send_message_length", 4321), + ("grpc.max_receive_message_length", 8765), + ("grpc.primary_user_agent", "durabletask-aio-tests"), ] - with patch('durabletask.aio.internal.shared.grpc_aio.insecure_channel') as mock_channel: + with patch("durabletask.aio.internal.shared.grpc_aio.insecure_channel") as mock_channel: get_grpc_aio_channel(HOST_ADDRESS, False, options=base_options) # Ensure called with options kwarg assert mock_channel.call_count == 1 args, kwargs = mock_channel.call_args assert args[0] == HOST_ADDRESS - assert 'options' in kwargs - opts = kwargs['options'] + assert "options" in kwargs + opts = kwargs["options"] # Check our base options made it through - assert ('grpc.max_send_message_length', 4321) in opts - assert ('grpc.max_receive_message_length', 8765) in opts - assert ('grpc.primary_user_agent', 'durabletask-aio-tests') in opts + assert ("grpc.max_send_message_length", 4321) in opts + assert ("grpc.max_receive_message_length", 8765) in opts + assert ("grpc.primary_user_agent", "durabletask-aio-tests") in opts def test_aio_channel_merges_env_keepalive_and_retry(monkeypatch: pytest.MonkeyPatch): @@ -42,53 +42,53 @@ def test_aio_channel_merges_env_keepalive_and_retry(monkeypatch: pytest.MonkeyPa initial_backoff_ms = 250 max_backoff_ms = 2000 backoff_multiplier = 1.5 - codes = ['RESOURCE_EXHAUSTED'] + codes = ["RESOURCE_EXHAUSTED"] service_config = { - 'methodConfig': [ + "methodConfig": [ { - 'name': [{'service': ''}], # match all services/methods - 'retryPolicy': { - 'maxAttempts': max_attempts, - 'initialBackoff': f'{initial_backoff_ms / 1000.0}s', - 'maxBackoff': f'{max_backoff_ms / 1000.0}s', - 'backoffMultiplier': backoff_multiplier, - 'retryableStatusCodes': codes, + "name": [{"service": ""}], # match all services/methods + "retryPolicy": { + "maxAttempts": max_attempts, + "initialBackoff": f"{initial_backoff_ms / 1000.0}s", + "maxBackoff": f"{max_backoff_ms / 1000.0}s", + "backoffMultiplier": backoff_multiplier, + "retryableStatusCodes": codes, }, } ] } - base_options = [('grpc.service_config', json.dumps(service_config))] + base_options = [("grpc.service_config", json.dumps(service_config))] - with patch('durabletask.aio.internal.shared.grpc_aio.insecure_channel') as mock_channel: + with patch("durabletask.aio.internal.shared.grpc_aio.insecure_channel") as mock_channel: get_grpc_aio_channel(HOST_ADDRESS, False, options=base_options) args, kwargs = mock_channel.call_args assert args[0] == HOST_ADDRESS - assert 'options' in kwargs - opts = kwargs['options'] + assert "options" in kwargs + opts = kwargs["options"] # Retry service config present and parses correctly - svc_cfg_str = _find_option(opts, 'grpc.service_config') + svc_cfg_str = _find_option(opts, "grpc.service_config") svc_cfg = json.loads(svc_cfg_str) - assert 'methodConfig' in svc_cfg and isinstance(svc_cfg['methodConfig'], list) - retry_policy = svc_cfg['methodConfig'][0]['retryPolicy'] - assert retry_policy['maxAttempts'] == 4 - assert retry_policy['initialBackoff'] == f'{250 / 1000.0}s' - assert retry_policy['maxBackoff'] == f'{2000 / 1000.0}s' - assert retry_policy['backoffMultiplier'] == 1.5 + assert "methodConfig" in svc_cfg and isinstance(svc_cfg["methodConfig"], list) + retry_policy = svc_cfg["methodConfig"][0]["retryPolicy"] + assert retry_policy["maxAttempts"] == 4 + assert retry_policy["initialBackoff"] == f"{250 / 1000.0}s" + assert retry_policy["maxBackoff"] == f"{2000 / 1000.0}s" + assert retry_policy["backoffMultiplier"] == 1.5 # Codes are upper-cased list - assert 'RESOURCE_EXHAUSTED' in retry_policy['retryableStatusCodes'] + assert "RESOURCE_EXHAUSTED" in retry_policy["retryableStatusCodes"] def test_aio_secure_channel_receives_options_when_secure_true(): - base_options = [('grpc.max_receive_message_length', 999999)] + base_options = [("grpc.max_receive_message_length", 999999)] with ( - patch('durabletask.aio.internal.shared.grpc_aio.secure_channel') as mock_channel, - patch('grpc.ssl_channel_credentials') as mock_credentials, + patch("durabletask.aio.internal.shared.grpc_aio.secure_channel") as mock_channel, + patch("grpc.ssl_channel_credentials") as mock_credentials, ): get_grpc_aio_channel(HOST_ADDRESS, True, options=base_options) args, kwargs = mock_channel.call_args assert args[0] == HOST_ADDRESS assert args[1] == mock_credentials.return_value - assert ('grpc.max_receive_message_length', 999999) in kwargs.get('options', []) + assert ("grpc.max_receive_message_length", 999999) in kwargs.get("options", []) diff --git a/tests/durabletask/test_grpc_channel_options.py b/tests/durabletask/test_grpc_channel_options.py index b8ac533..841d75b 100644 --- a/tests/durabletask/test_grpc_channel_options.py +++ b/tests/durabletask/test_grpc_channel_options.py @@ -1,38 +1,38 @@ import json -from unittest.mock import ANY, patch +from unittest.mock import patch import pytest from durabletask.internal.shared import get_grpc_channel -HOST_ADDRESS = 'localhost:50051' +HOST_ADDRESS = "localhost:50051" def _find_option(options, key): for k, v in options: if k == key: return v - raise AssertionError(f'Option with key {key} not found in options: {options}') + raise AssertionError(f"Option with key {key} not found in options: {options}") def test_sync_channel_passes_base_options_and_max_lengths(): base_options = [ - ('grpc.max_send_message_length', 1234), - ('grpc.max_receive_message_length', 5678), - ('grpc.primary_user_agent', 'durabletask-tests'), + ("grpc.max_send_message_length", 1234), + ("grpc.max_receive_message_length", 5678), + ("grpc.primary_user_agent", "durabletask-tests"), ] - with patch('grpc.insecure_channel') as mock_channel: + with patch("grpc.insecure_channel") as mock_channel: get_grpc_channel(HOST_ADDRESS, False, options=base_options) # Ensure called with options kwarg assert mock_channel.call_count == 1 args, kwargs = mock_channel.call_args assert args[0] == HOST_ADDRESS - assert 'options' in kwargs - opts = kwargs['options'] + assert "options" in kwargs + opts = kwargs["options"] # Check our base options made it through - assert ('grpc.max_send_message_length', 1234) in opts - assert ('grpc.max_receive_message_length', 5678) in opts - assert ('grpc.primary_user_agent', 'durabletask-tests') in opts + assert ("grpc.max_send_message_length", 1234) in opts + assert ("grpc.max_receive_message_length", 5678) in opts + assert ("grpc.primary_user_agent", "durabletask-tests") in opts def test_sync_channel_merges_env_keepalive_and_retry(monkeypatch: pytest.MonkeyPatch): @@ -42,40 +42,40 @@ def test_sync_channel_merges_env_keepalive_and_retry(monkeypatch: pytest.MonkeyP initial_backoff_ms = 250 max_backoff_ms = 2000 backoff_multiplier = 1.5 - codes = ['ABORTED'] + codes = ["ABORTED"] service_config = { - 'methodConfig': [ + "methodConfig": [ { - 'name': [{'service': ''}], # match all services/methods - 'retryPolicy': { - 'maxAttempts': max_attempts, - 'initialBackoff': f'{initial_backoff_ms / 1000.0}s', - 'maxBackoff': f'{max_backoff_ms / 1000.0}s', - 'backoffMultiplier': backoff_multiplier, - 'retryableStatusCodes': codes, + "name": [{"service": ""}], # match all services/methods + "retryPolicy": { + "maxAttempts": max_attempts, + "initialBackoff": f"{initial_backoff_ms / 1000.0}s", + "maxBackoff": f"{max_backoff_ms / 1000.0}s", + "backoffMultiplier": backoff_multiplier, + "retryableStatusCodes": codes, }, } ] } - base_options = [('grpc.service_config', json.dumps(service_config))] + base_options = [("grpc.service_config", json.dumps(service_config))] - with patch('grpc.insecure_channel') as mock_channel: + with patch("grpc.insecure_channel") as mock_channel: get_grpc_channel(HOST_ADDRESS, False, options=base_options) args, kwargs = mock_channel.call_args assert args[0] == HOST_ADDRESS - assert 'options' in kwargs - opts = kwargs['options'] + assert "options" in kwargs + opts = kwargs["options"] # Retry service config present and parses correctly - svc_cfg_str = _find_option(opts, 'grpc.service_config') + svc_cfg_str = _find_option(opts, "grpc.service_config") svc_cfg = json.loads(svc_cfg_str) - assert 'methodConfig' in svc_cfg and isinstance(svc_cfg['methodConfig'], list) - retry_policy = svc_cfg['methodConfig'][0]['retryPolicy'] - assert retry_policy['maxAttempts'] == 4 - assert retry_policy['initialBackoff'] == f'{250 / 1000.0}s' - assert retry_policy['maxBackoff'] == f'{2000 / 1000.0}s' - assert retry_policy['backoffMultiplier'] == 1.5 + assert "methodConfig" in svc_cfg and isinstance(svc_cfg["methodConfig"], list) + retry_policy = svc_cfg["methodConfig"][0]["retryPolicy"] + assert retry_policy["maxAttempts"] == 4 + assert retry_policy["initialBackoff"] == f"{250 / 1000.0}s" + assert retry_policy["maxBackoff"] == f"{2000 / 1000.0}s" + assert retry_policy["backoffMultiplier"] == 1.5 # Codes are upper-cased list - assert 'ABORTED' in retry_policy['retryableStatusCodes'] + assert "ABORTED" in retry_policy["retryableStatusCodes"] diff --git a/tests/durabletask/test_orchestration_e2e.py b/tests/durabletask/test_orchestration_e2e.py index 634dfd9..b60c035 100644 --- a/tests/durabletask/test_orchestration_e2e.py +++ b/tests/durabletask/test_orchestration_e2e.py @@ -30,7 +30,7 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): # set a custom max send length option c = client.TaskHubGrpcClient( channel_options=[ - ('grpc.max_send_message_length', 1024 * 1024), # 1MB + ("grpc.max_send_message_length", 1024 * 1024), # 1MB ] ) id = c.schedule_new_orchestration(empty_orchestrator) From 02ef910198e93998a9fe0c425cb388f25276342a Mon Sep 17 00:00:00 2001 From: Filinto Duran <1373693+filintod@users.noreply.github.com> Date: Fri, 31 Oct 2025 10:40:03 -0500 Subject: [PATCH 04/10] ruff Signed-off-by: Filinto Duran <1373693+filintod@users.noreply.github.com> --- durabletask/aio/internal/shared.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/durabletask/aio/internal/shared.py b/durabletask/aio/internal/shared.py index 113f73b..d69ba9d 100644 --- a/durabletask/aio/internal/shared.py +++ b/durabletask/aio/internal/shared.py @@ -1,7 +1,7 @@ # Copyright (c) The Dapr Authors. # Licensed under the MIT License. -from typing import Any, Dict, Optional, Sequence, Union +from typing import Dict, Optional, Sequence, Union import grpc from grpc import aio as grpc_aio @@ -52,7 +52,9 @@ def get_grpc_aio_channel( break # channel interceptors/options - channel_kwargs: Dict[str, ChannelArgumentType | Sequence[ClientInterceptor]] = dict(interceptors=interceptors) + channel_kwargs: Dict[str, ChannelArgumentType | Sequence[ClientInterceptor]] = dict( + interceptors=interceptors + ) if options is not None: validate_grpc_options(options) channel_kwargs["options"] = options From 76444c822f03ce41ce97b1578620a5e0c5cca291 Mon Sep 17 00:00:00 2001 From: Filinto Duran <1373693+filintod@users.noreply.github.com> Date: Mon, 3 Nov 2025 12:28:58 -0600 Subject: [PATCH 05/10] tackle feedback Signed-off-by: Filinto Duran <1373693+filintod@users.noreply.github.com> --- durabletask/aio/internal/shared.py | 13 ++--- durabletask/internal/shared.py | 14 ++--- tests/durabletask/test_client.py | 59 +++++++++++++++----- tests/durabletask/test_client_async.py | 77 +++++++++++++++++++------- 4 files changed, 111 insertions(+), 52 deletions(-) diff --git a/durabletask/aio/internal/shared.py b/durabletask/aio/internal/shared.py index d69ba9d..3825fe6 100644 --- a/durabletask/aio/internal/shared.py +++ b/durabletask/aio/internal/shared.py @@ -1,7 +1,7 @@ # Copyright (c) The Dapr Authors. # Licensed under the MIT License. -from typing import Dict, Optional, Sequence, Union +from typing import Optional, Sequence, Union import grpc from grpc import aio as grpc_aio @@ -51,19 +51,16 @@ def get_grpc_aio_channel( host_address = host_address[len(protocol) :] break - # channel interceptors/options - channel_kwargs: Dict[str, ChannelArgumentType | Sequence[ClientInterceptor]] = dict( - interceptors=interceptors - ) if options is not None: validate_grpc_options(options) - channel_kwargs["options"] = options if secure_channel: channel = grpc_aio.secure_channel( - host_address, grpc.ssl_channel_credentials(), **channel_kwargs + host_address, grpc.ssl_channel_credentials(), interceptors=interceptors, options=options ) else: - channel = grpc_aio.insecure_channel(host_address, **channel_kwargs) + channel = grpc_aio.insecure_channel( + host_address, interceptors=interceptors, options=options + ) return channel diff --git a/durabletask/internal/shared.py b/durabletask/internal/shared.py index 9c7f111..d971f1d 100644 --- a/durabletask/internal/shared.py +++ b/durabletask/internal/shared.py @@ -97,17 +97,11 @@ def get_grpc_channel( if options is not None: # validate all options keys prefix starts with `grpc.` validate_grpc_options(options) - if secure_channel: - channel = grpc.secure_channel( - host_address, grpc.ssl_channel_credentials(), options=options - ) - else: - channel = grpc.insecure_channel(host_address, options=options) + + if secure_channel: + channel = grpc.secure_channel(host_address, grpc.ssl_channel_credentials(), options=options) else: - if secure_channel: - channel = grpc.secure_channel(host_address, grpc.ssl_channel_credentials()) - else: - channel = grpc.insecure_channel(host_address) + channel = grpc.insecure_channel(host_address, options=options) # Apply interceptors ONLY if they exist if interceptors: diff --git a/tests/durabletask/test_client.py b/tests/durabletask/test_client.py index d55e0e0..7f61c2f 100644 --- a/tests/durabletask/test_client.py +++ b/tests/durabletask/test_client.py @@ -1,4 +1,4 @@ -from unittest.mock import ANY, patch +from unittest.mock import patch from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl from durabletask.internal.shared import get_default_host_address, get_grpc_channel @@ -11,7 +11,9 @@ def test_get_grpc_channel_insecure(): with patch("grpc.insecure_channel") as mock_channel: get_grpc_channel(HOST_ADDRESS, False, interceptors=INTERCEPTORS) - mock_channel.assert_called_once_with(HOST_ADDRESS) + args, kwargs = mock_channel.call_args + assert args[0] == HOST_ADDRESS + assert "options" in kwargs and kwargs["options"] is None def test_get_grpc_channel_secure(): @@ -20,13 +22,18 @@ def test_get_grpc_channel_secure(): patch("grpc.ssl_channel_credentials") as mock_credentials, ): get_grpc_channel(HOST_ADDRESS, True, interceptors=INTERCEPTORS) - mock_channel.assert_called_once_with(HOST_ADDRESS, mock_credentials.return_value) + args, kwargs = mock_channel.call_args + assert args[0] == HOST_ADDRESS + assert args[1] == mock_credentials.return_value + assert "options" in kwargs and kwargs["options"] is None def test_get_grpc_channel_default_host_address(): with patch("grpc.insecure_channel") as mock_channel: get_grpc_channel(None, False, interceptors=INTERCEPTORS) - mock_channel.assert_called_once_with(get_default_host_address()) + args, kwargs = mock_channel.call_args + assert args[0] == get_default_host_address() + assert "options" in kwargs and kwargs["options"] is None def test_get_grpc_channel_with_metadata(): @@ -35,7 +42,9 @@ def test_get_grpc_channel_with_metadata(): patch("grpc.intercept_channel") as mock_intercept_channel, ): get_grpc_channel(HOST_ADDRESS, False, interceptors=INTERCEPTORS) - mock_channel.assert_called_once_with(HOST_ADDRESS) + args, kwargs = mock_channel.call_args + assert args[0] == HOST_ADDRESS + assert "options" in kwargs and kwargs["options"] is None mock_intercept_channel.assert_called_once() # Capture and check the arguments passed to intercept_channel() @@ -54,40 +63,60 @@ def test_grpc_channel_with_host_name_protocol_stripping(): prefix = "grpc://" get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) - mock_insecure_channel.assert_called_with(host_name) + args, kwargs = mock_insecure_channel.call_args + assert args[0] == host_name + assert "options" in kwargs and kwargs["options"] is None prefix = "http://" get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) - mock_insecure_channel.assert_called_with(host_name) + args, kwargs = mock_insecure_channel.call_args + assert args[0] == host_name + assert "options" in kwargs and kwargs["options"] is None prefix = "HTTP://" get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) - mock_insecure_channel.assert_called_with(host_name) + args, kwargs = mock_insecure_channel.call_args + assert args[0] == host_name + assert "options" in kwargs and kwargs["options"] is None prefix = "GRPC://" get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) - mock_insecure_channel.assert_called_with(host_name) + args, kwargs = mock_insecure_channel.call_args + assert args[0] == host_name + assert "options" in kwargs and kwargs["options"] is None prefix = "" get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) - mock_insecure_channel.assert_called_with(host_name) + args, kwargs = mock_insecure_channel.call_args + assert args[0] == host_name + assert "options" in kwargs and kwargs["options"] is None prefix = "grpcs://" get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) - mock_secure_channel.assert_called_with(host_name, ANY) + args, kwargs = mock_secure_channel.call_args + assert args[0] == host_name + assert "options" in kwargs and kwargs["options"] is None prefix = "https://" get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) - mock_secure_channel.assert_called_with(host_name, ANY) + args, kwargs = mock_secure_channel.call_args + assert args[0] == host_name + assert "options" in kwargs and kwargs["options"] is None prefix = "HTTPS://" get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) - mock_secure_channel.assert_called_with(host_name, ANY) + args, kwargs = mock_secure_channel.call_args + assert args[0] == host_name + assert "options" in kwargs and kwargs["options"] is None prefix = "GRPCS://" get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) - mock_secure_channel.assert_called_with(host_name, ANY) + args, kwargs = mock_secure_channel.call_args + assert args[0] == host_name + assert "options" in kwargs and kwargs["options"] is None prefix = "" get_grpc_channel(prefix + host_name, True, interceptors=INTERCEPTORS) - mock_secure_channel.assert_called_with(host_name, ANY) + args, kwargs = mock_secure_channel.call_args + assert args[0] == host_name + assert "options" in kwargs and kwargs["options"] is None diff --git a/tests/durabletask/test_client_async.py b/tests/durabletask/test_client_async.py index 0588ff1..9b6dfc3 100644 --- a/tests/durabletask/test_client_async.py +++ b/tests/durabletask/test_client_async.py @@ -1,7 +1,7 @@ # Copyright (c) The Dapr Authors. # Licensed under the MIT License. -from unittest.mock import ANY, patch +from unittest.mock import patch from durabletask.aio.client import AsyncTaskHubGrpcClient from durabletask.aio.internal.grpc_interceptor import DefaultClientInterceptorImpl @@ -16,7 +16,10 @@ def test_get_grpc_aio_channel_insecure(): with patch("durabletask.aio.internal.shared.grpc_aio.insecure_channel") as mock_channel: get_grpc_aio_channel(HOST_ADDRESS, False, interceptors=INTERCEPTORS_AIO) - mock_channel.assert_called_once_with(HOST_ADDRESS, interceptors=INTERCEPTORS_AIO) + args, kwargs = mock_channel.call_args + assert args[0] == HOST_ADDRESS + assert kwargs.get("interceptors") == INTERCEPTORS_AIO + assert "options" in kwargs and kwargs["options"] is None def test_get_grpc_aio_channel_secure(): @@ -25,23 +28,29 @@ def test_get_grpc_aio_channel_secure(): patch("grpc.ssl_channel_credentials") as mock_credentials, ): get_grpc_aio_channel(HOST_ADDRESS, True, interceptors=INTERCEPTORS_AIO) - mock_channel.assert_called_once_with( - HOST_ADDRESS, mock_credentials.return_value, interceptors=INTERCEPTORS_AIO - ) + args, kwargs = mock_channel.call_args + assert args[0] == HOST_ADDRESS + assert args[1] == mock_credentials.return_value + assert kwargs.get("interceptors") == INTERCEPTORS_AIO + assert "options" in kwargs and kwargs["options"] is None def test_get_grpc_aio_channel_default_host_address(): with patch("durabletask.aio.internal.shared.grpc_aio.insecure_channel") as mock_channel: get_grpc_aio_channel(None, False, interceptors=INTERCEPTORS_AIO) - mock_channel.assert_called_once_with( - get_default_host_address(), interceptors=INTERCEPTORS_AIO - ) + args, kwargs = mock_channel.call_args + assert args[0] == get_default_host_address() + assert kwargs.get("interceptors") == INTERCEPTORS_AIO + assert "options" in kwargs and kwargs["options"] is None def test_get_grpc_aio_channel_with_interceptors(): with patch("durabletask.aio.internal.shared.grpc_aio.insecure_channel") as mock_channel: get_grpc_aio_channel(HOST_ADDRESS, False, interceptors=INTERCEPTORS_AIO) - mock_channel.assert_called_once_with(HOST_ADDRESS, interceptors=INTERCEPTORS_AIO) + args, kwargs = mock_channel.call_args + assert args[0] == HOST_ADDRESS + assert kwargs.get("interceptors") == INTERCEPTORS_AIO + assert "options" in kwargs and kwargs["options"] is None # Capture and check the arguments passed to insecure_channel() args, kwargs = mock_channel.call_args @@ -61,43 +70,73 @@ def test_grpc_aio_channel_with_host_name_protocol_stripping(): prefix = "grpc://" get_grpc_aio_channel(prefix + host_name, interceptors=INTERCEPTORS_AIO) - mock_insecure_channel.assert_called_with(host_name, interceptors=INTERCEPTORS_AIO) + args, kwargs = mock_insecure_channel.call_args + assert args[0] == host_name + assert kwargs.get("interceptors") == INTERCEPTORS_AIO + assert "options" in kwargs and kwargs["options"] is None prefix = "http://" get_grpc_aio_channel(prefix + host_name, interceptors=INTERCEPTORS_AIO) - mock_insecure_channel.assert_called_with(host_name, interceptors=INTERCEPTORS_AIO) + args, kwargs = mock_insecure_channel.call_args + assert args[0] == host_name + assert kwargs.get("interceptors") == INTERCEPTORS_AIO + assert "options" in kwargs and kwargs["options"] is None prefix = "HTTP://" get_grpc_aio_channel(prefix + host_name, interceptors=INTERCEPTORS_AIO) - mock_insecure_channel.assert_called_with(host_name, interceptors=INTERCEPTORS_AIO) + args, kwargs = mock_insecure_channel.call_args + assert args[0] == host_name + assert kwargs.get("interceptors") == INTERCEPTORS_AIO + assert "options" in kwargs and kwargs["options"] is None prefix = "GRPC://" get_grpc_aio_channel(prefix + host_name, interceptors=INTERCEPTORS_AIO) - mock_insecure_channel.assert_called_with(host_name, interceptors=INTERCEPTORS_AIO) + args, kwargs = mock_insecure_channel.call_args + assert args[0] == host_name + assert kwargs.get("interceptors") == INTERCEPTORS_AIO + assert "options" in kwargs and kwargs["options"] is None prefix = "" get_grpc_aio_channel(prefix + host_name, interceptors=INTERCEPTORS_AIO) - mock_insecure_channel.assert_called_with(host_name, interceptors=INTERCEPTORS_AIO) + args, kwargs = mock_insecure_channel.call_args + assert args[0] == host_name + assert kwargs.get("interceptors") == INTERCEPTORS_AIO + assert "options" in kwargs and kwargs["options"] is None prefix = "grpcs://" get_grpc_aio_channel(prefix + host_name, interceptors=INTERCEPTORS_AIO) - mock_secure_channel.assert_called_with(host_name, ANY, interceptors=INTERCEPTORS_AIO) + args, kwargs = mock_secure_channel.call_args + assert args[0] == host_name + assert kwargs.get("interceptors") == INTERCEPTORS_AIO + assert "options" in kwargs and kwargs["options"] is None prefix = "https://" get_grpc_aio_channel(prefix + host_name, interceptors=INTERCEPTORS_AIO) - mock_secure_channel.assert_called_with(host_name, ANY, interceptors=INTERCEPTORS_AIO) + args, kwargs = mock_secure_channel.call_args + assert args[0] == host_name + assert kwargs.get("interceptors") == INTERCEPTORS_AIO + assert "options" in kwargs and kwargs["options"] is None prefix = "HTTPS://" get_grpc_aio_channel(prefix + host_name, interceptors=INTERCEPTORS_AIO) - mock_secure_channel.assert_called_with(host_name, ANY, interceptors=INTERCEPTORS_AIO) + args, kwargs = mock_secure_channel.call_args + assert args[0] == host_name + assert kwargs.get("interceptors") == INTERCEPTORS_AIO + assert "options" in kwargs and kwargs["options"] is None prefix = "GRPCS://" get_grpc_aio_channel(prefix + host_name, interceptors=INTERCEPTORS_AIO) - mock_secure_channel.assert_called_with(host_name, ANY, interceptors=INTERCEPTORS_AIO) + args, kwargs = mock_secure_channel.call_args + assert args[0] == host_name + assert kwargs.get("interceptors") == INTERCEPTORS_AIO + assert "options" in kwargs and kwargs["options"] is None prefix = "" get_grpc_aio_channel(prefix + host_name, True, interceptors=INTERCEPTORS_AIO) - mock_secure_channel.assert_called_with(host_name, ANY, interceptors=INTERCEPTORS_AIO) + args, kwargs = mock_secure_channel.call_args + assert args[0] == host_name + assert kwargs.get("interceptors") == INTERCEPTORS_AIO + assert "options" in kwargs and kwargs["options"] is None def test_async_client_construct_with_metadata(): From f3442db022d1b729334a9f9f8029175629861512 Mon Sep 17 00:00:00 2001 From: Filinto Duran <1373693+filintod@users.noreply.github.com> Date: Mon, 3 Nov 2025 14:45:21 -0600 Subject: [PATCH 06/10] add missing grpc option in worker grpc client Signed-off-by: Filinto Duran <1373693+filintod@users.noreply.github.com> --- durabletask/worker.py | 4 +++- tests/durabletask/test_orchestration_e2e.py | 13 ++++++++----- .../test_orchestration_e2e_async.py | 19 ++++++++++++++----- 3 files changed, 25 insertions(+), 11 deletions(-) diff --git a/durabletask/worker.py b/durabletask/worker.py index 2d057e1..b15ee98 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -223,6 +223,7 @@ def __init__( secure_channel: bool = False, interceptors: Optional[Sequence[shared.ClientInterceptor]] = None, concurrency_options: Optional[ConcurrencyOptions] = None, + channel_options: Optional[Sequence[tuple[str, Any]]] = None, ): self._registry = _Registry() self._host_address = host_address if host_address else shared.get_default_host_address() @@ -230,6 +231,7 @@ def __init__( self._shutdown = Event() self._is_running = False self._secure_channel = secure_channel + self._channel_options = channel_options # Use provided concurrency options or create default ones self._concurrency_options = ( @@ -306,7 +308,7 @@ def create_fresh_connection(): current_stub = None try: current_channel = shared.get_grpc_channel( - self._host_address, self._secure_channel, self._interceptors + self._host_address, self._secure_channel, self._interceptors, options=self._channel_options ) current_stub = stubs.TaskHubSidecarServiceStub(current_channel) current_stub.Hello(empty_pb2.Empty()) diff --git a/tests/durabletask/test_orchestration_e2e.py b/tests/durabletask/test_orchestration_e2e.py index b60c035..08de87b 100644 --- a/tests/durabletask/test_orchestration_e2e.py +++ b/tests/durabletask/test_orchestration_e2e.py @@ -11,7 +11,8 @@ from durabletask import client, task, worker # NOTE: These tests assume a sidecar process is running. Example command: -# docker run --name durabletask-sidecar -p 4001:4001 --env 'DURABLETASK_SIDECAR_LOGLEVEL=Debug' --rm cgillum/durabletask-sidecar:latest start --backend Emulator +# dapr init || true +# dapr run --app-id test-app --dapr-grpc-port 4001 pytestmark = pytest.mark.e2e @@ -22,16 +23,18 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): nonlocal invoked # don't do this in a real app! invoked = True + channel_options = [ + ("grpc.max_send_message_length", 1024 * 1024), # 1MB + ] + # Start a worker, which will connect to the sidecar in a background thread - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(channel_options=channel_options) as w: w.add_orchestrator(empty_orchestrator) w.start() # set a custom max send length option c = client.TaskHubGrpcClient( - channel_options=[ - ("grpc.max_send_message_length", 1024 * 1024), # 1MB - ] + channel_options=channel_options ) id = c.schedule_new_orchestration(empty_orchestrator) state = c.wait_for_orchestration_completion(id, timeout=30) diff --git a/tests/durabletask/test_orchestration_e2e_async.py b/tests/durabletask/test_orchestration_e2e_async.py index 2e34603..78b7937 100644 --- a/tests/durabletask/test_orchestration_e2e_async.py +++ b/tests/durabletask/test_orchestration_e2e_async.py @@ -13,7 +13,7 @@ from durabletask.client import OrchestrationStatus # NOTE: These tests assume a sidecar process is running. Example command: -# go install github.com/microsoft/durabletask-go@main +# go install github.com/dapr/durabletask-go@main # durabletask-go --port 4001 pytestmark = [pytest.mark.e2e, pytest.mark.asyncio] @@ -25,12 +25,16 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): nonlocal invoked # don't do this in a real app! invoked = True + channel_options = [ + ("grpc.max_send_message_length", 1024 * 1024), # 1MB + ] + # Start a worker, which will connect to the sidecar in a background thread - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(channel_options=channel_options) as w: w.add_orchestrator(empty_orchestrator) w.start() - c = AsyncTaskHubGrpcClient() + c = AsyncTaskHubGrpcClient(channel_options=channel_options) id = await c.schedule_new_orchestration(empty_orchestrator) state = await c.wait_for_orchestration_completion(id, timeout=30) await c.aclose() @@ -58,13 +62,18 @@ def sequence(ctx: task.OrchestrationContext, start_val: int): numbers.append(current) return numbers + channel_options =[ + ("grpc.max_send_message_length", 1024 * 1024), # 1MB + ] # Start a worker, which will connect to the sidecar in a background thread - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker( + channel_options=channel_options + ) as w: w.add_orchestrator(sequence) w.add_activity(plus_one) w.start() - client = AsyncTaskHubGrpcClient() + client = AsyncTaskHubGrpcClient(channel_options=channel_options) id = await client.schedule_new_orchestration(sequence, input=1) state = await client.wait_for_orchestration_completion(id, timeout=30) await client.aclose() From f150b5a3029e7b3c805f42c86a92ddae61622cb1 Mon Sep 17 00:00:00 2001 From: Filinto Duran <1373693+filintod@users.noreply.github.com> Date: Tue, 4 Nov 2025 06:59:10 -0600 Subject: [PATCH 07/10] remove validate grpc key prefix Signed-off-by: Filinto Duran <1373693+filintod@users.noreply.github.com> --- durabletask/aio/internal/shared.py | 4 ---- durabletask/worker.py | 5 ++++- tests/durabletask/test_orchestration_e2e.py | 4 +--- tests/durabletask/test_orchestration_e2e_async.py | 10 ++++------ 4 files changed, 9 insertions(+), 14 deletions(-) diff --git a/durabletask/aio/internal/shared.py b/durabletask/aio/internal/shared.py index 3825fe6..cb4ffc0 100644 --- a/durabletask/aio/internal/shared.py +++ b/durabletask/aio/internal/shared.py @@ -11,7 +11,6 @@ INSECURE_PROTOCOLS, SECURE_PROTOCOLS, get_default_host_address, - validate_grpc_options, ) ClientInterceptor = Union[ @@ -51,9 +50,6 @@ def get_grpc_aio_channel( host_address = host_address[len(protocol) :] break - if options is not None: - validate_grpc_options(options) - if secure_channel: channel = grpc_aio.secure_channel( host_address, grpc.ssl_channel_credentials(), interceptors=interceptors, options=options diff --git a/durabletask/worker.py b/durabletask/worker.py index b15ee98..daa661b 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -308,7 +308,10 @@ def create_fresh_connection(): current_stub = None try: current_channel = shared.get_grpc_channel( - self._host_address, self._secure_channel, self._interceptors, options=self._channel_options + self._host_address, + self._secure_channel, + self._interceptors, + options=self._channel_options, ) current_stub = stubs.TaskHubSidecarServiceStub(current_channel) current_stub.Hello(empty_pb2.Empty()) diff --git a/tests/durabletask/test_orchestration_e2e.py b/tests/durabletask/test_orchestration_e2e.py index 08de87b..225456d 100644 --- a/tests/durabletask/test_orchestration_e2e.py +++ b/tests/durabletask/test_orchestration_e2e.py @@ -33,9 +33,7 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): w.start() # set a custom max send length option - c = client.TaskHubGrpcClient( - channel_options=channel_options - ) + c = client.TaskHubGrpcClient(channel_options=channel_options) id = c.schedule_new_orchestration(empty_orchestrator) state = c.wait_for_orchestration_completion(id, timeout=30) diff --git a/tests/durabletask/test_orchestration_e2e_async.py b/tests/durabletask/test_orchestration_e2e_async.py index 78b7937..c441bdc 100644 --- a/tests/durabletask/test_orchestration_e2e_async.py +++ b/tests/durabletask/test_orchestration_e2e_async.py @@ -62,13 +62,11 @@ def sequence(ctx: task.OrchestrationContext, start_val: int): numbers.append(current) return numbers - channel_options =[ - ("grpc.max_send_message_length", 1024 * 1024), # 1MB - ] + channel_options = [ + ("grpc.max_send_message_length", 1024 * 1024), # 1MB + ] # Start a worker, which will connect to the sidecar in a background thread - with worker.TaskHubGrpcWorker( - channel_options=channel_options - ) as w: + with worker.TaskHubGrpcWorker(channel_options=channel_options) as w: w.add_orchestrator(sequence) w.add_activity(plus_one) w.start() From d7910e76da18489b9a7a41d23bfebee060066e1d Mon Sep 17 00:00:00 2001 From: Filinto Duran <1373693+filintod@users.noreply.github.com> Date: Tue, 4 Nov 2025 07:37:09 -0600 Subject: [PATCH 08/10] include not-saved file on removal of validate grpc prefix Signed-off-by: Filinto Duran <1373693+filintod@users.noreply.github.com> --- durabletask/internal/shared.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/durabletask/internal/shared.py b/durabletask/internal/shared.py index d971f1d..34e5f73 100644 --- a/durabletask/internal/shared.py +++ b/durabletask/internal/shared.py @@ -51,17 +51,6 @@ def get_default_host_address() -> str: return "localhost:4001" -def validate_grpc_options(options: ChannelArgumentType): - """Validate that all gRPC options are valid. Mainly checking keys. Values can be string, int, float, bool and pointer""" - for key, value in options: - if not isinstance(key, str): - raise ValueError(f"gRPC option key must be a string. Invalid key: {key}") - if not all(key.startswith("grpc.") for key, _ in options): - raise ValueError( - f"All options keys must start with `grpc.`. Invalid options: {options}" - ) - - def get_grpc_channel( host_address: Optional[str], secure_channel: bool = False, @@ -93,11 +82,6 @@ def get_grpc_channel( host_address = host_address[len(protocol) :] break - # Create the base channel - if options is not None: - # validate all options keys prefix starts with `grpc.` - validate_grpc_options(options) - if secure_channel: channel = grpc.secure_channel(host_address, grpc.ssl_channel_credentials(), options=options) else: From 7980fd86b396b86ed36a94f5a5bdc75f635ebd10 Mon Sep 17 00:00:00 2001 From: Filinto Duran <1373693+filintod@users.noreply.github.com> Date: Tue, 4 Nov 2025 07:50:04 -0600 Subject: [PATCH 09/10] ruff/lint Signed-off-by: Filinto Duran <1373693+filintod@users.noreply.github.com> --- durabletask/internal/shared.py | 1 - 1 file changed, 1 deletion(-) diff --git a/durabletask/internal/shared.py b/durabletask/internal/shared.py index 34e5f73..3adb6b1 100644 --- a/durabletask/internal/shared.py +++ b/durabletask/internal/shared.py @@ -9,7 +9,6 @@ from typing import Any, Optional, Sequence, Union import grpc -from grpc.aio import ChannelArgumentType ClientInterceptor = Union[ grpc.UnaryUnaryClientInterceptor, From 5cddf4fc165e5718a793fd2e4be9f559259d7f00 Mon Sep 17 00:00:00 2001 From: Filinto Duran <1373693+filintod@users.noreply.github.com> Date: Tue, 4 Nov 2025 09:19:54 -0600 Subject: [PATCH 10/10] consolidate some options tests Signed-off-by: Filinto Duran <1373693+filintod@users.noreply.github.com> --- tests/durabletask/test_client.py | 20 ++++ tests/durabletask/test_client_async.py | 20 ++++ .../test_grpc_aio_channel_options.py | 94 ------------------- .../durabletask/test_grpc_channel_options.py | 81 ---------------- 4 files changed, 40 insertions(+), 175 deletions(-) delete mode 100644 tests/durabletask/test_grpc_aio_channel_options.py delete mode 100644 tests/durabletask/test_grpc_channel_options.py diff --git a/tests/durabletask/test_client.py b/tests/durabletask/test_client.py index 7f61c2f..b671cf8 100644 --- a/tests/durabletask/test_client.py +++ b/tests/durabletask/test_client.py @@ -120,3 +120,23 @@ def test_grpc_channel_with_host_name_protocol_stripping(): args, kwargs = mock_secure_channel.call_args assert args[0] == host_name assert "options" in kwargs and kwargs["options"] is None + + +def test_sync_channel_passes_base_options_and_max_lengths(): + base_options = [ + ("grpc.max_send_message_length", 1234), + ("grpc.max_receive_message_length", 5678), + ("grpc.primary_user_agent", "durabletask-tests"), + ] + with patch("grpc.insecure_channel") as mock_channel: + get_grpc_channel(HOST_ADDRESS, False, options=base_options) + # Ensure called with options kwarg + assert mock_channel.call_count == 1 + args, kwargs = mock_channel.call_args + assert args[0] == HOST_ADDRESS + assert "options" in kwargs + opts = kwargs["options"] + # Check our base options made it through + assert ("grpc.max_send_message_length", 1234) in opts + assert ("grpc.max_receive_message_length", 5678) in opts + assert ("grpc.primary_user_agent", "durabletask-tests") in opts diff --git a/tests/durabletask/test_client_async.py b/tests/durabletask/test_client_async.py index 9b6dfc3..43e8870 100644 --- a/tests/durabletask/test_client_async.py +++ b/tests/durabletask/test_client_async.py @@ -149,3 +149,23 @@ def test_async_client_construct_with_metadata(): interceptors = kwargs["interceptors"] assert isinstance(interceptors[0], DefaultClientInterceptorImpl) assert interceptors[0]._metadata == METADATA + + +def test_aio_channel_passes_base_options_and_max_lengths(): + base_options = [ + ("grpc.max_send_message_length", 4321), + ("grpc.max_receive_message_length", 8765), + ("grpc.primary_user_agent", "durabletask-aio-tests"), + ] + with patch("durabletask.aio.internal.shared.grpc_aio.insecure_channel") as mock_channel: + get_grpc_aio_channel(HOST_ADDRESS, False, options=base_options) + # Ensure called with options kwarg + assert mock_channel.call_count == 1 + args, kwargs = mock_channel.call_args + assert args[0] == HOST_ADDRESS + assert "options" in kwargs + opts = kwargs["options"] + # Check our base options made it through + assert ("grpc.max_send_message_length", 4321) in opts + assert ("grpc.max_receive_message_length", 8765) in opts + assert ("grpc.primary_user_agent", "durabletask-aio-tests") in opts diff --git a/tests/durabletask/test_grpc_aio_channel_options.py b/tests/durabletask/test_grpc_aio_channel_options.py deleted file mode 100644 index 2f64577..0000000 --- a/tests/durabletask/test_grpc_aio_channel_options.py +++ /dev/null @@ -1,94 +0,0 @@ -import json -from unittest.mock import patch - -import pytest - -from durabletask.aio.internal.shared import get_grpc_aio_channel - -HOST_ADDRESS = "localhost:50051" - - -def _find_option(options, key): - for k, v in options: - if k == key: - return v - raise AssertionError(f"Option with key {key} not found in options: {options}") - - -def test_aio_channel_passes_base_options_and_max_lengths(): - base_options = [ - ("grpc.max_send_message_length", 4321), - ("grpc.max_receive_message_length", 8765), - ("grpc.primary_user_agent", "durabletask-aio-tests"), - ] - with patch("durabletask.aio.internal.shared.grpc_aio.insecure_channel") as mock_channel: - get_grpc_aio_channel(HOST_ADDRESS, False, options=base_options) - # Ensure called with options kwarg - assert mock_channel.call_count == 1 - args, kwargs = mock_channel.call_args - assert args[0] == HOST_ADDRESS - assert "options" in kwargs - opts = kwargs["options"] - # Check our base options made it through - assert ("grpc.max_send_message_length", 4321) in opts - assert ("grpc.max_receive_message_length", 8765) in opts - assert ("grpc.primary_user_agent", "durabletask-aio-tests") in opts - - -def test_aio_channel_merges_env_keepalive_and_retry(monkeypatch: pytest.MonkeyPatch): - # retry grpc option - # service_config ref => https://github.com/grpc/grpc-proto/blob/master/grpc/service_config/service_config.proto#L44 - max_attempts = 4 - initial_backoff_ms = 250 - max_backoff_ms = 2000 - backoff_multiplier = 1.5 - codes = ["RESOURCE_EXHAUSTED"] - service_config = { - "methodConfig": [ - { - "name": [{"service": ""}], # match all services/methods - "retryPolicy": { - "maxAttempts": max_attempts, - "initialBackoff": f"{initial_backoff_ms / 1000.0}s", - "maxBackoff": f"{max_backoff_ms / 1000.0}s", - "backoffMultiplier": backoff_multiplier, - "retryableStatusCodes": codes, - }, - } - ] - } - - base_options = [("grpc.service_config", json.dumps(service_config))] - - with patch("durabletask.aio.internal.shared.grpc_aio.insecure_channel") as mock_channel: - get_grpc_aio_channel(HOST_ADDRESS, False, options=base_options) - - args, kwargs = mock_channel.call_args - assert args[0] == HOST_ADDRESS - assert "options" in kwargs - opts = kwargs["options"] - - # Retry service config present and parses correctly - svc_cfg_str = _find_option(opts, "grpc.service_config") - svc_cfg = json.loads(svc_cfg_str) - assert "methodConfig" in svc_cfg and isinstance(svc_cfg["methodConfig"], list) - retry_policy = svc_cfg["methodConfig"][0]["retryPolicy"] - assert retry_policy["maxAttempts"] == 4 - assert retry_policy["initialBackoff"] == f"{250 / 1000.0}s" - assert retry_policy["maxBackoff"] == f"{2000 / 1000.0}s" - assert retry_policy["backoffMultiplier"] == 1.5 - # Codes are upper-cased list - assert "RESOURCE_EXHAUSTED" in retry_policy["retryableStatusCodes"] - - -def test_aio_secure_channel_receives_options_when_secure_true(): - base_options = [("grpc.max_receive_message_length", 999999)] - with ( - patch("durabletask.aio.internal.shared.grpc_aio.secure_channel") as mock_channel, - patch("grpc.ssl_channel_credentials") as mock_credentials, - ): - get_grpc_aio_channel(HOST_ADDRESS, True, options=base_options) - args, kwargs = mock_channel.call_args - assert args[0] == HOST_ADDRESS - assert args[1] == mock_credentials.return_value - assert ("grpc.max_receive_message_length", 999999) in kwargs.get("options", []) diff --git a/tests/durabletask/test_grpc_channel_options.py b/tests/durabletask/test_grpc_channel_options.py deleted file mode 100644 index 841d75b..0000000 --- a/tests/durabletask/test_grpc_channel_options.py +++ /dev/null @@ -1,81 +0,0 @@ -import json -from unittest.mock import patch - -import pytest - -from durabletask.internal.shared import get_grpc_channel - -HOST_ADDRESS = "localhost:50051" - - -def _find_option(options, key): - for k, v in options: - if k == key: - return v - raise AssertionError(f"Option with key {key} not found in options: {options}") - - -def test_sync_channel_passes_base_options_and_max_lengths(): - base_options = [ - ("grpc.max_send_message_length", 1234), - ("grpc.max_receive_message_length", 5678), - ("grpc.primary_user_agent", "durabletask-tests"), - ] - with patch("grpc.insecure_channel") as mock_channel: - get_grpc_channel(HOST_ADDRESS, False, options=base_options) - # Ensure called with options kwarg - assert mock_channel.call_count == 1 - args, kwargs = mock_channel.call_args - assert args[0] == HOST_ADDRESS - assert "options" in kwargs - opts = kwargs["options"] - # Check our base options made it through - assert ("grpc.max_send_message_length", 1234) in opts - assert ("grpc.max_receive_message_length", 5678) in opts - assert ("grpc.primary_user_agent", "durabletask-tests") in opts - - -def test_sync_channel_merges_env_keepalive_and_retry(monkeypatch: pytest.MonkeyPatch): - # retry grpc option - # service_config ref => https://github.com/grpc/grpc-proto/blob/master/grpc/service_config/service_config.proto#L44 - max_attempts = 4 - initial_backoff_ms = 250 - max_backoff_ms = 2000 - backoff_multiplier = 1.5 - codes = ["ABORTED"] - service_config = { - "methodConfig": [ - { - "name": [{"service": ""}], # match all services/methods - "retryPolicy": { - "maxAttempts": max_attempts, - "initialBackoff": f"{initial_backoff_ms / 1000.0}s", - "maxBackoff": f"{max_backoff_ms / 1000.0}s", - "backoffMultiplier": backoff_multiplier, - "retryableStatusCodes": codes, - }, - } - ] - } - - base_options = [("grpc.service_config", json.dumps(service_config))] - - with patch("grpc.insecure_channel") as mock_channel: - get_grpc_channel(HOST_ADDRESS, False, options=base_options) - - args, kwargs = mock_channel.call_args - assert args[0] == HOST_ADDRESS - assert "options" in kwargs - opts = kwargs["options"] - - # Retry service config present and parses correctly - svc_cfg_str = _find_option(opts, "grpc.service_config") - svc_cfg = json.loads(svc_cfg_str) - assert "methodConfig" in svc_cfg and isinstance(svc_cfg["methodConfig"], list) - retry_policy = svc_cfg["methodConfig"][0]["retryPolicy"] - assert retry_policy["maxAttempts"] == 4 - assert retry_policy["initialBackoff"] == f"{250 / 1000.0}s" - assert retry_policy["maxBackoff"] == f"{2000 / 1000.0}s" - assert retry_policy["backoffMultiplier"] == 1.5 - # Codes are upper-cased list - assert "ABORTED" in retry_policy["retryableStatusCodes"]