Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions sdk/src/opendecree/_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from __future__ import annotations

from typing import Any

import grpc

_DEFAULT_KEEPALIVE_TIME_MS = 30000
Expand Down Expand Up @@ -99,6 +101,7 @@ def create_aio_channel(
insecure: bool = True,
credentials: grpc.ChannelCredentials | None = None,
token: str | None = None,
interceptors: list[Any] | None = None,
max_send_message_length: int | None = None,
max_recv_message_length: int | None = None,
keepalive_time_ms: int = _DEFAULT_KEEPALIVE_TIME_MS,
Expand Down Expand Up @@ -127,6 +130,7 @@ def create_aio_channel(
reconnect_backoff_initial_ms,
reconnect_backoff_max_ms,
)
aio_interceptors = tuple(interceptors) if interceptors else ()

channel_creds: grpc.ChannelCredentials | None = credentials
if channel_creds is None and not insecure:
Expand All @@ -137,6 +141,8 @@ def create_aio_channel(
channel_creds = grpc.composite_channel_credentials(
channel_creds, _token_call_credentials(token)
)
return grpc.aio.secure_channel(target, channel_creds, options=options)
return grpc.aio.secure_channel(
target, channel_creds, options=options, interceptors=aio_interceptors
)

return grpc.aio.insecure_channel(target, options=options)
return grpc.aio.insecure_channel(target, options=options, interceptors=aio_interceptors)
12 changes: 10 additions & 2 deletions sdk/src/opendecree/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import warnings
from datetime import timedelta
from typing import TYPE_CHECKING, overload
from typing import TYPE_CHECKING, Any, overload

if TYPE_CHECKING:
from opendecree.async_watcher import AsyncConfigWatcher
Expand Down Expand Up @@ -54,6 +54,7 @@ def __init__(
timeout: float = 10.0,
retry: RetryConfig | None = None,
check_version: bool = False,
interceptors: list[Any] | None = None,
) -> None:
"""Create a new AsyncConfigClient.

Expand All @@ -76,6 +77,9 @@ def __init__(
check_version: When True, run :meth:`check_compatibility` lazily
on the first RPC call. Raises :exc:`IncompatibleServerError`
if the server version is outside the supported range.
interceptors: Optional list of :class:`grpc.aio.ClientInterceptor`
instances to inject (e.g., for logging, tracing, or metrics).
Passed directly to the ``grpc.aio`` channel.
"""
self._timeout = timeout
self._retry = retry if retry is not None else RetryConfig()
Expand All @@ -102,7 +106,11 @@ def __init__(
subject=subject, role=role, tenant_id=tenant_id, token=metadata_token
)
self._channel = create_aio_channel(
target, insecure=insecure, credentials=credentials, token=channel_token
target,
insecure=insecure,
credentials=credentials,
token=channel_token,
interceptors=interceptors,
)

cs_pb2, cs_grpc = ensure_stubs()
Expand Down
18 changes: 13 additions & 5 deletions sdk/src/opendecree/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import warnings
from datetime import timedelta
from typing import TYPE_CHECKING, overload
from typing import TYPE_CHECKING, Any, overload

if TYPE_CHECKING:
from opendecree.watcher import ConfigWatcher
Expand Down Expand Up @@ -57,6 +57,7 @@ def __init__(
timeout: float = 10.0,
retry: RetryConfig | None = None,
check_version: bool = False,
interceptors: list[Any] | None = None,
) -> None:
"""Create a new ConfigClient.

Expand All @@ -79,6 +80,12 @@ def __init__(
check_version: When True, run :meth:`check_compatibility` lazily
on the first RPC call. Raises :exc:`IncompatibleServerError`
if the server version is outside the supported range.
interceptors: Optional list of
:class:`grpc.UnaryUnaryClientInterceptor` /
:class:`grpc.UnaryStreamClientInterceptor` instances to inject
(e.g., for logging, tracing, or metrics). User-supplied
interceptors are applied outermost (before the SDK's internal
auth interceptor).
"""
self._timeout = timeout
self._retry = retry if retry is not None else RetryConfig()
Expand All @@ -102,15 +109,16 @@ def __init__(
metadata = _build_metadata(
subject=subject, role=role, tenant_id=tenant_id, token=metadata_token
)
interceptors: list[grpc.UnaryUnaryClientInterceptor] = []
# User interceptors are outermost; auth interceptor runs inside them.
all_interceptors: list[Any] = list(interceptors) if interceptors else []
if metadata:
interceptors.append(AuthInterceptor(metadata))
all_interceptors.append(AuthInterceptor(metadata))

channel = create_channel(
target, insecure=insecure, credentials=credentials, token=channel_token
)
if interceptors:
self._channel = grpc.intercept_channel(channel, *interceptors)
if all_interceptors:
self._channel = grpc.intercept_channel(channel, *all_interceptors)
else:
self._channel = channel
self._raw_channel = channel # keep ref for close()
Expand Down
21 changes: 21 additions & 0 deletions sdk/tests/test_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,3 +296,24 @@ async def _empty_stream():
ctx = client.watch("t1")
async with ctx as watcher:
assert watcher is not None

def test_custom_interceptors_passed_to_channel(self):
custom = MagicMock()
with patch("opendecree.async_client.create_aio_channel") as mock_ch:
mock_ch.return_value = MagicMock()
AsyncConfigClient("localhost:9090", interceptors=[custom])
assert mock_ch.call_args.kwargs["interceptors"] == [custom]

def test_no_interceptors_by_default(self):
with patch("opendecree.async_client.create_aio_channel") as mock_ch:
mock_ch.return_value = MagicMock()
AsyncConfigClient("localhost:9090")
assert mock_ch.call_args.kwargs.get("interceptors") is None

def test_multiple_custom_interceptors_preserved(self):
a = MagicMock()
b = MagicMock()
with patch("opendecree.async_client.create_aio_channel") as mock_ch:
mock_ch.return_value = MagicMock()
AsyncConfigClient("localhost:9090", interceptors=[a, b])
assert mock_ch.call_args.kwargs["interceptors"] == [a, b]
40 changes: 40 additions & 0 deletions sdk/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,3 +338,43 @@ def test_watch_returns_context(self):
assert ctx is not None
with ctx as watcher:
assert watcher is not None

def test_custom_interceptors_passed_to_intercept_channel(self):
custom = MagicMock(spec=grpc.UnaryUnaryClientInterceptor)
with patch("opendecree.client.create_channel") as mock_ch:
mock_channel = MagicMock()
mock_ch.return_value = mock_channel
with patch("opendecree.client.grpc.intercept_channel") as mock_intercept:
mock_intercept.return_value = mock_channel
opendecree.ConfigClient("localhost:9090", interceptors=[custom])
args = mock_intercept.call_args[0]
assert custom in args

def test_custom_interceptors_outermost(self):
"""User interceptors must come before AuthInterceptor."""
from opendecree._interceptors import AuthInterceptor

custom = MagicMock(spec=grpc.UnaryUnaryClientInterceptor)
with patch("opendecree.client.create_channel") as mock_ch:
mock_channel = MagicMock()
mock_ch.return_value = mock_channel
with patch("opendecree.client.grpc.intercept_channel") as mock_intercept:
mock_intercept.return_value = mock_channel
opendecree.ConfigClient("localhost:9090", subject="s", interceptors=[custom])
args = mock_intercept.call_args[0]
# args[0] is the channel; args[1:] are interceptors in order
interceptors_in_order = args[1:]
assert interceptors_in_order[0] is custom
assert isinstance(interceptors_in_order[1], AuthInterceptor)

def test_custom_interceptors_no_auth(self):
"""Custom interceptors work even when no auth metadata is set."""
custom = MagicMock(spec=grpc.UnaryUnaryClientInterceptor)
with patch("opendecree.client.create_channel") as mock_ch:
mock_channel = MagicMock()
mock_ch.return_value = mock_channel
with patch("opendecree.client.grpc.intercept_channel") as mock_intercept:
mock_intercept.return_value = mock_channel
opendecree.ConfigClient("localhost:9090", role="", interceptors=[custom])
args = mock_intercept.call_args[0]
assert args[1] is custom