diff --git a/sdk/src/opendecree/_channel.py b/sdk/src/opendecree/_channel.py index 06c23c1..f741e77 100644 --- a/sdk/src/opendecree/_channel.py +++ b/sdk/src/opendecree/_channel.py @@ -2,6 +2,8 @@ from __future__ import annotations +from typing import Any + import grpc _DEFAULT_KEEPALIVE_TIME_MS = 30000 @@ -99,6 +101,7 @@ def create_aio_channel( insecure: bool = True, credentials: grpc.ChannelCredentials | None = None, token: str | None = None, + interceptors: list[Any] | None = None, max_send_message_length: int | None = None, max_recv_message_length: int | None = None, keepalive_time_ms: int = _DEFAULT_KEEPALIVE_TIME_MS, @@ -127,6 +130,7 @@ def create_aio_channel( reconnect_backoff_initial_ms, reconnect_backoff_max_ms, ) + aio_interceptors = tuple(interceptors) if interceptors else () channel_creds: grpc.ChannelCredentials | None = credentials if channel_creds is None and not insecure: @@ -137,6 +141,8 @@ def create_aio_channel( channel_creds = grpc.composite_channel_credentials( channel_creds, _token_call_credentials(token) ) - return grpc.aio.secure_channel(target, channel_creds, options=options) + return grpc.aio.secure_channel( + target, channel_creds, options=options, interceptors=aio_interceptors + ) - return grpc.aio.insecure_channel(target, options=options) + return grpc.aio.insecure_channel(target, options=options, interceptors=aio_interceptors) diff --git a/sdk/src/opendecree/async_client.py b/sdk/src/opendecree/async_client.py index 3f3ba7b..c44878f 100644 --- a/sdk/src/opendecree/async_client.py +++ b/sdk/src/opendecree/async_client.py @@ -9,7 +9,7 @@ import warnings from datetime import timedelta -from typing import TYPE_CHECKING, overload +from typing import TYPE_CHECKING, Any, overload if TYPE_CHECKING: from opendecree.async_watcher import AsyncConfigWatcher @@ -54,6 +54,7 @@ def __init__( timeout: float = 10.0, retry: RetryConfig | None = None, check_version: bool = False, + interceptors: list[Any] | None = None, ) -> None: """Create a new AsyncConfigClient. @@ -76,6 +77,9 @@ def __init__( check_version: When True, run :meth:`check_compatibility` lazily on the first RPC call. Raises :exc:`IncompatibleServerError` if the server version is outside the supported range. + interceptors: Optional list of :class:`grpc.aio.ClientInterceptor` + instances to inject (e.g., for logging, tracing, or metrics). + Passed directly to the ``grpc.aio`` channel. """ self._timeout = timeout self._retry = retry if retry is not None else RetryConfig() @@ -102,7 +106,11 @@ def __init__( subject=subject, role=role, tenant_id=tenant_id, token=metadata_token ) self._channel = create_aio_channel( - target, insecure=insecure, credentials=credentials, token=channel_token + target, + insecure=insecure, + credentials=credentials, + token=channel_token, + interceptors=interceptors, ) cs_pb2, cs_grpc = ensure_stubs() diff --git a/sdk/src/opendecree/client.py b/sdk/src/opendecree/client.py index bd29b28..e32c813 100644 --- a/sdk/src/opendecree/client.py +++ b/sdk/src/opendecree/client.py @@ -13,7 +13,7 @@ import warnings from datetime import timedelta -from typing import TYPE_CHECKING, overload +from typing import TYPE_CHECKING, Any, overload if TYPE_CHECKING: from opendecree.watcher import ConfigWatcher @@ -57,6 +57,7 @@ def __init__( timeout: float = 10.0, retry: RetryConfig | None = None, check_version: bool = False, + interceptors: list[Any] | None = None, ) -> None: """Create a new ConfigClient. @@ -79,6 +80,12 @@ def __init__( check_version: When True, run :meth:`check_compatibility` lazily on the first RPC call. Raises :exc:`IncompatibleServerError` if the server version is outside the supported range. + interceptors: Optional list of + :class:`grpc.UnaryUnaryClientInterceptor` / + :class:`grpc.UnaryStreamClientInterceptor` instances to inject + (e.g., for logging, tracing, or metrics). User-supplied + interceptors are applied outermost (before the SDK's internal + auth interceptor). """ self._timeout = timeout self._retry = retry if retry is not None else RetryConfig() @@ -102,15 +109,16 @@ def __init__( metadata = _build_metadata( subject=subject, role=role, tenant_id=tenant_id, token=metadata_token ) - interceptors: list[grpc.UnaryUnaryClientInterceptor] = [] + # User interceptors are outermost; auth interceptor runs inside them. + all_interceptors: list[Any] = list(interceptors) if interceptors else [] if metadata: - interceptors.append(AuthInterceptor(metadata)) + all_interceptors.append(AuthInterceptor(metadata)) channel = create_channel( target, insecure=insecure, credentials=credentials, token=channel_token ) - if interceptors: - self._channel = grpc.intercept_channel(channel, *interceptors) + if all_interceptors: + self._channel = grpc.intercept_channel(channel, *all_interceptors) else: self._channel = channel self._raw_channel = channel # keep ref for close() diff --git a/sdk/tests/test_async_client.py b/sdk/tests/test_async_client.py index 9dd1714..83c35d6 100644 --- a/sdk/tests/test_async_client.py +++ b/sdk/tests/test_async_client.py @@ -296,3 +296,24 @@ async def _empty_stream(): ctx = client.watch("t1") async with ctx as watcher: assert watcher is not None + + def test_custom_interceptors_passed_to_channel(self): + custom = MagicMock() + with patch("opendecree.async_client.create_aio_channel") as mock_ch: + mock_ch.return_value = MagicMock() + AsyncConfigClient("localhost:9090", interceptors=[custom]) + assert mock_ch.call_args.kwargs["interceptors"] == [custom] + + def test_no_interceptors_by_default(self): + with patch("opendecree.async_client.create_aio_channel") as mock_ch: + mock_ch.return_value = MagicMock() + AsyncConfigClient("localhost:9090") + assert mock_ch.call_args.kwargs.get("interceptors") is None + + def test_multiple_custom_interceptors_preserved(self): + a = MagicMock() + b = MagicMock() + with patch("opendecree.async_client.create_aio_channel") as mock_ch: + mock_ch.return_value = MagicMock() + AsyncConfigClient("localhost:9090", interceptors=[a, b]) + assert mock_ch.call_args.kwargs["interceptors"] == [a, b] diff --git a/sdk/tests/test_client.py b/sdk/tests/test_client.py index 99ee4ee..d75ed3c 100644 --- a/sdk/tests/test_client.py +++ b/sdk/tests/test_client.py @@ -338,3 +338,43 @@ def test_watch_returns_context(self): assert ctx is not None with ctx as watcher: assert watcher is not None + + def test_custom_interceptors_passed_to_intercept_channel(self): + custom = MagicMock(spec=grpc.UnaryUnaryClientInterceptor) + with patch("opendecree.client.create_channel") as mock_ch: + mock_channel = MagicMock() + mock_ch.return_value = mock_channel + with patch("opendecree.client.grpc.intercept_channel") as mock_intercept: + mock_intercept.return_value = mock_channel + opendecree.ConfigClient("localhost:9090", interceptors=[custom]) + args = mock_intercept.call_args[0] + assert custom in args + + def test_custom_interceptors_outermost(self): + """User interceptors must come before AuthInterceptor.""" + from opendecree._interceptors import AuthInterceptor + + custom = MagicMock(spec=grpc.UnaryUnaryClientInterceptor) + with patch("opendecree.client.create_channel") as mock_ch: + mock_channel = MagicMock() + mock_ch.return_value = mock_channel + with patch("opendecree.client.grpc.intercept_channel") as mock_intercept: + mock_intercept.return_value = mock_channel + opendecree.ConfigClient("localhost:9090", subject="s", interceptors=[custom]) + args = mock_intercept.call_args[0] + # args[0] is the channel; args[1:] are interceptors in order + interceptors_in_order = args[1:] + assert interceptors_in_order[0] is custom + assert isinstance(interceptors_in_order[1], AuthInterceptor) + + def test_custom_interceptors_no_auth(self): + """Custom interceptors work even when no auth metadata is set.""" + custom = MagicMock(spec=grpc.UnaryUnaryClientInterceptor) + with patch("opendecree.client.create_channel") as mock_ch: + mock_channel = MagicMock() + mock_ch.return_value = mock_channel + with patch("opendecree.client.grpc.intercept_channel") as mock_intercept: + mock_intercept.return_value = mock_channel + opendecree.ConfigClient("localhost:9090", role="", interceptors=[custom]) + args = mock_intercept.call_args[0] + assert args[1] is custom