Skip to content

Commit

Permalink
added request and response hooks for grpc client
Browse files Browse the repository at this point in the history
  • Loading branch information
prsnca committed Mar 2, 2023
1 parent 4a859e3 commit b2a45a3
Show file tree
Hide file tree
Showing 6 changed files with 380 additions and 15 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fix exception in Urllib3 when dealing with filelike body.
([#1399](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1399))

- Add request and response hooks for GRPC instrumentation (client only)
([#1706](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1706))

### Added

- Add connection attributes to sqlalchemy connect span
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,8 @@ def __init__(self, filter_=None):
else:
filter_ = any_of(filter_, excluded_service_filter)
self._filter = filter_
self._request_hook = None
self._response_hook = None

# Figures out which channel type we need to wrap
def _which_channel(self, kwargs):
Expand All @@ -455,6 +457,8 @@ def instrumentation_dependencies(self) -> Collection[str]:
return _instruments

def _instrument(self, **kwargs):
self._request_hook = kwargs.get("request_hook")
self._response_hook = kwargs.get("response_hook")
for ctype in self._which_channel(kwargs):
_wrap(
"grpc",
Expand All @@ -469,11 +473,15 @@ def _uninstrument(self, **kwargs):
def wrapper_fn(self, original_func, instance, args, kwargs):
channel = original_func(*args, **kwargs)
tracer_provider = kwargs.get("tracer_provider")
request_hook = self._request_hook
response_hook = self._response_hook
return intercept_channel(
channel,
client_interceptor(
tracer_provider=tracer_provider,
filter_=self._filter,
request_hook=request_hook,
response_hook=response_hook,
),
)

Expand All @@ -499,6 +507,8 @@ def __init__(self, filter_=None):
else:
filter_ = any_of(filter_, excluded_service_filter)
self._filter = filter_
self._request_hook = None
self._response_hook = None

def instrumentation_dependencies(self) -> Collection[str]:
return _instruments
Expand All @@ -507,20 +517,28 @@ def _add_interceptors(self, tracer_provider, kwargs):
if "interceptors" in kwargs and kwargs["interceptors"]:
kwargs["interceptors"] = (
aio_client_interceptors(
tracer_provider=tracer_provider, filter_=self._filter
tracer_provider=tracer_provider,
filter_=self._filter,
request_hook=self._request_hook,
response_hook=self._response_hook,
)
+ kwargs["interceptors"]
)
else:
kwargs["interceptors"] = aio_client_interceptors(
tracer_provider=tracer_provider, filter_=self._filter
tracer_provider=tracer_provider,
filter_=self._filter,
request_hook=self._request_hook,
response_hook=self._response_hook,
)

return kwargs

def _instrument(self, **kwargs):
self._original_insecure = grpc.aio.insecure_channel
self._original_secure = grpc.aio.secure_channel
self._request_hook = kwargs.get("request_hook")
self._response_hook = kwargs.get("response_hook")
tracer_provider = kwargs.get("tracer_provider")

def insecure(*args, **kwargs):
Expand All @@ -541,7 +559,9 @@ def _uninstrument(self, **kwargs):
grpc.aio.secure_channel = self._original_secure


def client_interceptor(tracer_provider=None, filter_=None):
def client_interceptor(
tracer_provider=None, filter_=None, request_hook=None, response_hook=None
):
"""Create a gRPC client channel interceptor.
Args:
Expand All @@ -558,7 +578,12 @@ def client_interceptor(tracer_provider=None, filter_=None):

tracer = trace.get_tracer(__name__, __version__, tracer_provider)

return _client.OpenTelemetryClientInterceptor(tracer, filter_=filter_)
return _client.OpenTelemetryClientInterceptor(
tracer,
filter_=filter_,
request_hook=request_hook,
response_hook=response_hook,
)


def server_interceptor(tracer_provider=None, filter_=None):
Expand All @@ -581,7 +606,9 @@ def server_interceptor(tracer_provider=None, filter_=None):
return _server.OpenTelemetryServerInterceptor(tracer, filter_=filter_)


def aio_client_interceptors(tracer_provider=None, filter_=None):
def aio_client_interceptors(
tracer_provider=None, filter_=None, request_hook=None, response_hook=None
):
"""Create a gRPC client channel interceptor.
Args:
Expand All @@ -595,10 +622,30 @@ def aio_client_interceptors(tracer_provider=None, filter_=None):
tracer = trace.get_tracer(__name__, __version__, tracer_provider)

return [
_aio_client.UnaryUnaryAioClientInterceptor(tracer, filter_=filter_),
_aio_client.UnaryStreamAioClientInterceptor(tracer, filter_=filter_),
_aio_client.StreamUnaryAioClientInterceptor(tracer, filter_=filter_),
_aio_client.StreamStreamAioClientInterceptor(tracer, filter_=filter_),
_aio_client.UnaryUnaryAioClientInterceptor(
tracer,
filter_=filter_,
request_hook=request_hook,
response_hook=response_hook,
),
_aio_client.UnaryStreamAioClientInterceptor(
tracer,
filter_=filter_,
request_hook=request_hook,
response_hook=response_hook,
),
_aio_client.StreamUnaryAioClientInterceptor(
tracer,
filter_=filter_,
request_hook=request_hook,
response_hook=response_hook,
),
_aio_client.StreamStreamAioClientInterceptor(
tracer,
filter_=filter_,
request_hook=request_hook,
response_hook=response_hook,
),
]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import functools
import logging
from collections import OrderedDict

import grpc
Expand All @@ -28,8 +29,10 @@
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace.status import Status, StatusCode

logger = logging.getLogger(__name__)

def _unary_done_callback(span, code, details):

def _unary_done_callback(span, code, details, response_hook):
def callback(call):
try:
span.set_attribute(
Expand All @@ -43,6 +46,8 @@ def callback(call):
description=details,
)
)
response_hook(span, details)

finally:
span.end()

Expand Down Expand Up @@ -110,7 +115,11 @@ async def _wrap_unary_response(self, continuation, span):
code = await call.code()
details = await call.details()

call.add_done_callback(_unary_done_callback(span, code, details))
call.add_done_callback(
_unary_done_callback(
span, code, details, self._call_response_hook
)
)

return call
except grpc.aio.AioRpcError as exc:
Expand All @@ -120,6 +129,8 @@ async def _wrap_unary_response(self, continuation, span):
async def _wrap_stream_response(self, span, call):
try:
async for response in call:
if self._response_hook:
self._call_response_hook(span, response)
yield response
except Exception as exc:
self.add_error_details_to_span(span, exc)
Expand Down Expand Up @@ -151,6 +162,9 @@ async def intercept_unary_unary(
) as span:
new_details = self.propagate_trace_in_details(client_call_details)

if self._request_hook:
self._call_request_hook(span, request)

continuation_with_args = functools.partial(
continuation, new_details, request
)
Expand All @@ -175,7 +189,8 @@ async def intercept_unary_stream(
new_details = self.propagate_trace_in_details(client_call_details)

resp = await continuation(new_details, request)

if self._request_hook:
self._call_request_hook(span, request)
return self._wrap_stream_response(span, resp)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@

"""Implementation of the invocation-side open-telemetry interceptor."""

import logging
from collections import OrderedDict
from typing import MutableMapping
from typing import Callable, MutableMapping

import grpc

Expand All @@ -33,6 +34,8 @@
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace.status import Status, StatusCode

logger = logging.getLogger(__name__)


class _CarrierSetter(Setter):
"""We use a custom setter in order to be able to lower case
Expand All @@ -59,12 +62,27 @@ def callback(response_future):
return callback


def _safe_invoke(function: Callable, *args):
function_name = "<unknown>"
try:
function_name = function.__name__
function(*args)
except Exception as ex: # pylint:disable=broad-except
logger.error(
"Error when invoking function '%s'", function_name, exc_info=ex
)


class OpenTelemetryClientInterceptor(
grpcext.UnaryClientInterceptor, grpcext.StreamClientInterceptor
):
def __init__(self, tracer, filter_=None):
def __init__(
self, tracer, filter_=None, request_hook=None, response_hook=None
):
self._tracer = tracer
self._filter = filter_
self._request_hook = request_hook
self._response_hook = response_hook

def _start_span(self, method, **kwargs):
service, meth = method.lstrip("/").split("/", 1)
Expand Down Expand Up @@ -99,6 +117,8 @@ def _trace_result(self, span, rpc_info, result):
if isinstance(result, tuple):
response = result[0]
rpc_info.response = response
if self._response_hook:
self._call_response_hook(span, response)
span.end()
return result

Expand Down Expand Up @@ -127,7 +147,8 @@ def _intercept(self, request, metadata, client_info, invoker):
timeout=client_info.timeout,
request=request,
)

if self._request_hook:
self._call_request_hook(span, request)
result = invoker(request, metadata)
except Exception as exc:
if isinstance(exc, grpc.RpcError):
Expand All @@ -148,6 +169,16 @@ def _intercept(self, request, metadata, client_info, invoker):
span.end()
return self._trace_result(span, rpc_info, result)

def _call_request_hook(self, span, request):
if not callable(self._request_hook):
return
_safe_invoke(self._request_hook, span, request)

def _call_response_hook(self, span, response):
if not callable(self._response_hook):
return
_safe_invoke(self._response_hook, span, response)

def intercept_unary(self, request, metadata, client_info, invoker):
if self._filter is not None and not self._filter(client_info):
return invoker(request, metadata)
Expand Down
Loading

0 comments on commit b2a45a3

Please sign in to comment.