Skip to content

Commit 6561cfa

Browse files
feat: add basic interceptor to client (#1206)
1 parent 72dfdc4 commit 6561cfa

File tree

9 files changed

+522
-30
lines changed

9 files changed

+522
-30
lines changed

google/cloud/bigtable/data/_async/client.py

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
cast,
2020
Any,
2121
AsyncIterable,
22+
Callable,
2223
Optional,
2324
Set,
2425
Sequence,
@@ -99,18 +100,24 @@
99100
)
100101
from google.cloud.bigtable.data._async.mutations_batcher import _MB_SIZE
101102
from google.cloud.bigtable.data._async._swappable_channel import (
102-
AsyncSwappableChannel,
103+
AsyncSwappableChannel as SwappableChannelType,
104+
)
105+
from google.cloud.bigtable.data._async.metrics_interceptor import (
106+
AsyncBigtableMetricsInterceptor as MetricsInterceptorType,
103107
)
104108
else:
105109
from typing import Iterable # noqa: F401
106110
from grpc import insecure_channel
111+
from grpc import intercept_channel
107112
from google.cloud.bigtable_v2.services.bigtable.transports import BigtableGrpcTransport as TransportType # type: ignore
108113
from google.cloud.bigtable_v2.services.bigtable import BigtableClient as GapicClient # type: ignore
109114
from google.cloud.bigtable.data._sync_autogen.mutations_batcher import _MB_SIZE
110115
from google.cloud.bigtable.data._sync_autogen._swappable_channel import ( # noqa: F401
111-
SwappableChannel,
116+
SwappableChannel as SwappableChannelType,
117+
)
118+
from google.cloud.bigtable.data._sync_autogen.metrics_interceptor import ( # noqa: F401
119+
BigtableMetricsInterceptor as MetricsInterceptorType,
112120
)
113-
114121

115122
if TYPE_CHECKING:
116123
from google.cloud.bigtable.data._helpers import RowKeySamples
@@ -205,7 +212,7 @@ def __init__(
205212
credentials = google.auth.credentials.AnonymousCredentials()
206213
if project is None:
207214
project = _DEFAULT_BIGTABLE_EMULATOR_CLIENT
208-
215+
self._metrics_interceptor = MetricsInterceptorType()
209216
# initialize client
210217
ClientWithProject.__init__(
211218
self,
@@ -259,12 +266,11 @@ def __init__(
259266
stacklevel=2,
260267
)
261268

262-
@CrossSync.convert(replace_symbols={"AsyncSwappableChannel": "SwappableChannel"})
263-
def _build_grpc_channel(self, *args, **kwargs) -> AsyncSwappableChannel:
269+
def _build_grpc_channel(self, *args, **kwargs) -> SwappableChannelType:
264270
"""
265271
This method is called by the gapic transport to create a grpc channel.
266272
267-
The init arguments passed down are captured in a partial used by AsyncSwappableChannel
273+
The init arguments passed down are captured in a partial used by SwappableChannel
268274
to create new channel instances in the future, as part of the channel refresh logic
269275
270276
Emulators always use an inseucre channel
@@ -275,12 +281,30 @@ def _build_grpc_channel(self, *args, **kwargs) -> AsyncSwappableChannel:
275281
Returns:
276282
a custom wrapped swappable channel
277283
"""
284+
create_channel_fn: Callable[[], Channel]
278285
if self._emulator_host is not None:
279-
# emulators use insecure channel
286+
# Emulators use insecure channels
280287
create_channel_fn = partial(insecure_channel, self._emulator_host)
281-
else:
288+
elif CrossSync.is_async:
289+
# For async client, use the default create_channel.
282290
create_channel_fn = partial(TransportType.create_channel, *args, **kwargs)
283-
return AsyncSwappableChannel(create_channel_fn)
291+
else:
292+
# For sync client, wrap create_channel with interceptors.
293+
def sync_create_channel_fn():
294+
return intercept_channel(
295+
TransportType.create_channel(*args, **kwargs),
296+
self._metrics_interceptor,
297+
)
298+
299+
create_channel_fn = sync_create_channel_fn
300+
301+
# Instantiate SwappableChannelType with the determined creation function.
302+
new_channel = SwappableChannelType(create_channel_fn)
303+
if CrossSync.is_async:
304+
# Attach async interceptors to the channel instance itself.
305+
new_channel._unary_unary_interceptors.append(self._metrics_interceptor)
306+
new_channel._unary_stream_interceptors.append(self._metrics_interceptor)
307+
return new_channel
284308

285309
@property
286310
def universe_domain(self) -> str:
@@ -402,7 +426,7 @@ def _invalidate_channel_stubs(self):
402426
self.transport._stubs = {}
403427
self.transport._prep_wrapped_messages(self.client_info)
404428

405-
@CrossSync.convert(replace_symbols={"AsyncSwappableChannel": "SwappableChannel"})
429+
@CrossSync.convert
406430
async def _manage_channel(
407431
self,
408432
refresh_interval_min: float = 60 * 35,
@@ -427,10 +451,10 @@ async def _manage_channel(
427451
grace_period: time to allow previous channel to serve existing
428452
requests before closing, in seconds
429453
"""
430-
if not isinstance(self.transport.grpc_channel, AsyncSwappableChannel):
454+
if not isinstance(self.transport.grpc_channel, SwappableChannelType):
431455
warnings.warn("Channel does not support auto-refresh.")
432456
return
433-
super_channel: AsyncSwappableChannel = self.transport.grpc_channel
457+
super_channel: SwappableChannelType = self.transport.grpc_channel
434458
first_refresh = self._channel_init_time + random.uniform(
435459
refresh_interval_min, refresh_interval_max
436460
)
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from __future__ import annotations
15+
16+
from google.cloud.bigtable.data._cross_sync import CrossSync
17+
18+
if CrossSync.is_async:
19+
from grpc.aio import UnaryUnaryClientInterceptor
20+
from grpc.aio import UnaryStreamClientInterceptor
21+
else:
22+
from grpc import UnaryUnaryClientInterceptor
23+
from grpc import UnaryStreamClientInterceptor
24+
25+
26+
__CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync_autogen.metrics_interceptor"
27+
28+
29+
@CrossSync.convert_class(sync_name="BigtableMetricsInterceptor")
30+
class AsyncBigtableMetricsInterceptor(
31+
UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor
32+
):
33+
"""
34+
An async gRPC interceptor to add client metadata and print server metadata.
35+
"""
36+
37+
@CrossSync.convert
38+
async def intercept_unary_unary(self, continuation, client_call_details, request):
39+
"""
40+
Interceptor for unary rpcs:
41+
- MutateRow
42+
- CheckAndMutateRow
43+
- ReadModifyWriteRow
44+
"""
45+
try:
46+
call = await continuation(client_call_details, request)
47+
return call
48+
except Exception as rpc_error:
49+
raise rpc_error
50+
51+
@CrossSync.convert
52+
async def intercept_unary_stream(self, continuation, client_call_details, request):
53+
"""
54+
Interceptor for streaming rpcs:
55+
- ReadRows
56+
- MutateRows
57+
- SampleRowKeys
58+
"""
59+
try:
60+
return self._streaming_generator_wrapper(
61+
await continuation(client_call_details, request)
62+
)
63+
except Exception as rpc_error:
64+
# handle errors while intializing stream
65+
raise rpc_error
66+
67+
@staticmethod
68+
@CrossSync.convert
69+
async def _streaming_generator_wrapper(call):
70+
"""
71+
Wrapped generator to be returned by intercept_unary_stream.
72+
"""
73+
try:
74+
async for response in call:
75+
yield response
76+
except Exception as e:
77+
# handle errors while processing stream
78+
raise e

google/cloud/bigtable/data/_sync_autogen/client.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
# This file is automatically generated by CrossSync. Do not edit manually.
1818

1919
from __future__ import annotations
20-
from typing import cast, Any, Optional, Set, Sequence, TYPE_CHECKING
20+
from typing import cast, Any, Callable, Optional, Set, Sequence, TYPE_CHECKING
2121
import abc
2222
import time
2323
import warnings
@@ -77,12 +77,18 @@
7777
from google.cloud.bigtable.data._cross_sync import CrossSync
7878
from typing import Iterable
7979
from grpc import insecure_channel
80+
from grpc import intercept_channel
8081
from google.cloud.bigtable_v2.services.bigtable.transports import (
8182
BigtableGrpcTransport as TransportType,
8283
)
8384
from google.cloud.bigtable_v2.services.bigtable import BigtableClient as GapicClient
8485
from google.cloud.bigtable.data._sync_autogen.mutations_batcher import _MB_SIZE
85-
from google.cloud.bigtable.data._sync_autogen._swappable_channel import SwappableChannel
86+
from google.cloud.bigtable.data._sync_autogen._swappable_channel import (
87+
SwappableChannel as SwappableChannelType,
88+
)
89+
from google.cloud.bigtable.data._sync_autogen.metrics_interceptor import (
90+
BigtableMetricsInterceptor as MetricsInterceptorType,
91+
)
8692

8793
if TYPE_CHECKING:
8894
from google.cloud.bigtable.data._helpers import RowKeySamples
@@ -145,6 +151,7 @@ def __init__(
145151
credentials = google.auth.credentials.AnonymousCredentials()
146152
if project is None:
147153
project = _DEFAULT_BIGTABLE_EMULATOR_CLIENT
154+
self._metrics_interceptor = MetricsInterceptorType()
148155
ClientWithProject.__init__(
149156
self,
150157
credentials=credentials,
@@ -188,7 +195,7 @@ def __init__(
188195
stacklevel=2,
189196
)
190197

191-
def _build_grpc_channel(self, *args, **kwargs) -> SwappableChannel:
198+
def _build_grpc_channel(self, *args, **kwargs) -> SwappableChannelType:
192199
"""This method is called by the gapic transport to create a grpc channel.
193200
194201
The init arguments passed down are captured in a partial used by SwappableChannel
@@ -201,11 +208,20 @@ def _build_grpc_channel(self, *args, **kwargs) -> SwappableChannel:
201208
- **kwargs: keyword arguments passed by the gapic layer to create a new channel with
202209
Returns:
203210
a custom wrapped swappable channel"""
211+
create_channel_fn: Callable[[], Channel]
204212
if self._emulator_host is not None:
205213
create_channel_fn = partial(insecure_channel, self._emulator_host)
206214
else:
207-
create_channel_fn = partial(TransportType.create_channel, *args, **kwargs)
208-
return SwappableChannel(create_channel_fn)
215+
216+
def sync_create_channel_fn():
217+
return intercept_channel(
218+
TransportType.create_channel(*args, **kwargs),
219+
self._metrics_interceptor,
220+
)
221+
222+
create_channel_fn = sync_create_channel_fn
223+
new_channel = SwappableChannelType(create_channel_fn)
224+
return new_channel
209225

210226
@property
211227
def universe_domain(self) -> str:
@@ -326,10 +342,10 @@ def _manage_channel(
326342
between `refresh_interval_min` and `refresh_interval_max`
327343
grace_period: time to allow previous channel to serve existing
328344
requests before closing, in seconds"""
329-
if not isinstance(self.transport.grpc_channel, SwappableChannel):
345+
if not isinstance(self.transport.grpc_channel, SwappableChannelType):
330346
warnings.warn("Channel does not support auto-refresh.")
331347
return
332-
super_channel: SwappableChannel = self.transport.grpc_channel
348+
super_channel: SwappableChannelType = self.transport.grpc_channel
333349
first_refresh = self._channel_init_time + random.uniform(
334350
refresh_interval_min, refresh_interval_max
335351
)
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# This file is automatically generated by CrossSync. Do not edit manually.
16+
17+
from __future__ import annotations
18+
from grpc import UnaryUnaryClientInterceptor
19+
from grpc import UnaryStreamClientInterceptor
20+
21+
22+
class BigtableMetricsInterceptor(
23+
UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor
24+
):
25+
"""
26+
An async gRPC interceptor to add client metadata and print server metadata.
27+
"""
28+
29+
def intercept_unary_unary(self, continuation, client_call_details, request):
30+
"""Interceptor for unary rpcs:
31+
- MutateRow
32+
- CheckAndMutateRow
33+
- ReadModifyWriteRow"""
34+
try:
35+
call = continuation(client_call_details, request)
36+
return call
37+
except Exception as rpc_error:
38+
raise rpc_error
39+
40+
def intercept_unary_stream(self, continuation, client_call_details, request):
41+
"""Interceptor for streaming rpcs:
42+
- ReadRows
43+
- MutateRows
44+
- SampleRowKeys"""
45+
try:
46+
return self._streaming_generator_wrapper(
47+
continuation(client_call_details, request)
48+
)
49+
except Exception as rpc_error:
50+
raise rpc_error
51+
52+
@staticmethod
53+
def _streaming_generator_wrapper(call):
54+
"""Wrapped generator to be returned by intercept_unary_stream."""
55+
try:
56+
for response in call:
57+
yield response
58+
except Exception as e:
59+
raise e

tests/system/data/test_system_async.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -285,23 +285,28 @@ async def test_channel_refresh(self, table_id, instance_id, temp_rows):
285285
async with client.get_table(instance_id, table_id) as table:
286286
rows = await table.read_rows({})
287287
channel_wrapper = client.transport.grpc_channel
288-
first_channel = client.transport.grpc_channel._channel
288+
first_channel = channel_wrapper._channel
289289
assert len(rows) == 2
290290
await CrossSync.sleep(2)
291291
rows_after_refresh = await table.read_rows({})
292292
assert len(rows_after_refresh) == 2
293293
assert client.transport.grpc_channel is channel_wrapper
294-
assert client.transport.grpc_channel._channel is not first_channel
295-
# ensure gapic's logging interceptor is still active
294+
updated_channel = channel_wrapper._channel
295+
assert updated_channel is not first_channel
296+
# ensure interceptors are kept (gapic's logging interceptor, and metric interceptor)
296297
if CrossSync.is_async:
297-
interceptors = (
298-
client.transport.grpc_channel._channel._unary_unary_interceptors
299-
)
300-
assert GapicInterceptor in [type(i) for i in interceptors]
298+
unary_interceptors = updated_channel._unary_unary_interceptors
299+
assert len(unary_interceptors) == 2
300+
assert GapicInterceptor in [type(i) for i in unary_interceptors]
301+
assert client._metrics_interceptor in unary_interceptors
302+
stream_interceptors = updated_channel._unary_stream_interceptors
303+
assert len(stream_interceptors) == 1
304+
assert client._metrics_interceptor in stream_interceptors
301305
else:
302306
assert isinstance(
303307
client.transport._logged_channel._interceptor, GapicInterceptor
304308
)
309+
assert updated_channel._interceptor == client._metrics_interceptor
305310
finally:
306311
await client.close()
307312

tests/system/data/test_system_autogen.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,16 +237,18 @@ def test_channel_refresh(self, table_id, instance_id, temp_rows):
237237
with client.get_table(instance_id, table_id) as table:
238238
rows = table.read_rows({})
239239
channel_wrapper = client.transport.grpc_channel
240-
first_channel = client.transport.grpc_channel._channel
240+
first_channel = channel_wrapper._channel
241241
assert len(rows) == 2
242242
CrossSync._Sync_Impl.sleep(2)
243243
rows_after_refresh = table.read_rows({})
244244
assert len(rows_after_refresh) == 2
245245
assert client.transport.grpc_channel is channel_wrapper
246-
assert client.transport.grpc_channel._channel is not first_channel
246+
updated_channel = channel_wrapper._channel
247+
assert updated_channel is not first_channel
247248
assert isinstance(
248249
client.transport._logged_channel._interceptor, GapicInterceptor
249250
)
251+
assert updated_channel._interceptor == client._metrics_interceptor
250252
finally:
251253
client.close()
252254

0 commit comments

Comments
 (0)