Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for grpc aio server interceptor #1870

Merged
68 changes: 68 additions & 0 deletions elasticapm/contrib/grpc/async_server_interceptor.py
@@ -0,0 +1,68 @@
# BSD 3-Clause License
#
# Copyright (c) 2022, Elasticsearch BV
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import inspect

import grpc

import elasticapm
from elasticapm.contrib.grpc.server_interceptor import _ServicerContextWrapper, _wrap_rpc_behavior, get_trace_parent


class _AsyncServerInterceptor(grpc.aio.ServerInterceptor):
async def intercept_service(self, continuation, handler_call_details):
def transaction_wrapper(behavior, request_streaming, response_streaming):
async def _interceptor(request_or_iterator, context):
if request_streaming or response_streaming: # only unary-unary is supported
return behavior(request_or_iterator, context)
tp = get_trace_parent(handler_call_details)
client = elasticapm.get_client()
transaction = client.begin_transaction("request", trace_parent=tp)
try:
result = behavior(request_or_iterator, _ServicerContextWrapper(context, transaction))

# This is so we can support both sync and async rpc functions
if inspect.isawaitable(result):
result = await result

if transaction and not transaction.outcome:
transaction.set_success()
return result
except Exception:
if transaction:
transaction.set_failure()
client.capture_exception(handled=False)
raise
finally:
client.end_transaction(name=handler_call_details.method)

return _interceptor

return _wrap_rpc_behavior(await continuation(handler_call_details), transaction_wrapper)
24 changes: 14 additions & 10 deletions elasticapm/contrib/grpc/server_interceptor.py
Expand Up @@ -62,6 +62,19 @@ def _wrap_rpc_behavior(handler, continuation):
)


def get_trace_parent(handler_call_details):
traceparent, tracestate = None, None
for metadata in handler_call_details.invocation_metadata:
if metadata.key == "traceparent":
traceparent = metadata.value
elif metadata.key == "tracestate":
tracestate = metadata.key
if traceparent:
return TraceParent.from_string(traceparent, tracestate)
else:
return None


class _ServicerContextWrapper(wrapt.ObjectProxy):
def __init__(self, wrapped, transaction):
self._self_transaction = transaction
Expand All @@ -87,16 +100,7 @@ def transaction_wrapper(behavior, request_streaming, response_streaming):
def _interceptor(request_or_iterator, context):
if request_streaming or response_streaming: # only unary-unary is supported
return behavior(request_or_iterator, context)
traceparent, tracestate = None, None
for metadata in handler_call_details.invocation_metadata:
if metadata.key == "traceparent":
traceparent = metadata.value
elif metadata.key == "tracestate":
tracestate = metadata.key
if traceparent:
tp = TraceParent.from_string(traceparent, tracestate)
else:
tp = None
tp = get_trace_parent(handler_call_details)
client = elasticapm.get_client()
transaction = client.begin_transaction("request", trace_parent=tp)
try:
Expand Down
29 changes: 29 additions & 0 deletions elasticapm/instrumentation/packages/grpc.py
Expand Up @@ -29,6 +29,8 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


from pkg_resources import parse_version

from elasticapm.instrumentation.packages.asyncio.base import AbstractInstrumentedModule
from elasticapm.utils.logging import get_logger

Expand Down Expand Up @@ -74,3 +76,30 @@ def call(self, module, method, wrapped, instance, args, kwargs):
else:
kwargs["interceptors"] = interceptors
return wrapped(*args, **kwargs)


class GRPCAsyncServerInstrumentation(AbstractInstrumentedModule):
name = "grpc_async_server_instrumentation"
creates_transactions = True
instrument_list = [("grpc", "aio.server")]

def get_instrument_list(self):
import grpc
felipou marked this conversation as resolved.
Show resolved Hide resolved

# Check against the oldest version that I believe has the expected API
if parse_version(grpc.__version__) >= parse_version("1.33.1") and hasattr(grpc, "aio"):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is how I'm checking if we should try to instrument the async version, but I'm not sure what is the recommended way for doing this kind of thing here. I've thought of leaving just the check for the aio module's existence (and perhaps add a check for a server attribute inside grpc.aio), but since I wrote the version check first, I just left it there too.

return super().get_instrument_list()
else:
return []

def call(self, module, method, wrapped, instance, args, kwargs):
from elasticapm.contrib.grpc.async_server_interceptor import _AsyncServerInterceptor

interceptors = kwargs.get("interceptors") or (args[2] if len(args) > 2 else [])
interceptors.insert(0, _AsyncServerInterceptor())
if len(args) > 2:
args = list(args)
args[2] = interceptors
else:
kwargs["interceptors"] = interceptors
return wrapped(*args, **kwargs)
1 change: 1 addition & 0 deletions elasticapm/instrumentation/register.py
Expand Up @@ -70,6 +70,7 @@
"elasticapm.instrumentation.packages.kafka.KafkaInstrumentation",
"elasticapm.instrumentation.packages.grpc.GRPCClientInstrumentation",
"elasticapm.instrumentation.packages.grpc.GRPCServerInstrumentation",
"elasticapm.instrumentation.packages.grpc.GRPCAsyncServerInstrumentation",
basepi marked this conversation as resolved.
Show resolved Hide resolved
}

if sys.version_info >= (3, 7):
Expand Down
42 changes: 41 additions & 1 deletion tests/contrib/grpc/grpc_app/server.py
Expand Up @@ -28,7 +28,9 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import asyncio
import logging
import os
import sys
from concurrent import futures

Expand Down Expand Up @@ -67,6 +69,30 @@ def GetServerResponseException(self, request, context):
raise Exception("oh no")


class TestServiceAsync(pb2_grpc.TestServiceServicer):
def __init__(self, *args, **kwargs):
pass

async def GetServerResponse(self, request, context):
message = request.message
result = f'Hello I am up and running received "{message}" message from you'
result = {"message": result, "received": True}

return pb2.MessageResponse(**result)

async def GetServerResponseAbort(self, request, context):
await context.abort(grpc.StatusCode.INTERNAL, "foo")

async def GetServerResponseUnavailable(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNAVAILABLE)
context.set_details("Method not available")
return pb2.MessageResponse(message="foo", received=True)

async def GetServerResponseException(self, request, context):
raise Exception("oh no")


def serve(port):
apm_client = GRPCApmClient(
service_name="grpc-server", disable_metrics="*", api_request_time="100ms", central_config="False"
Expand All @@ -78,10 +104,24 @@ def serve(port):
server.wait_for_termination()


async def serve_async(port):
apm_client = GRPCApmClient(
service_name="grpc-server", disable_metrics="*", api_request_time="100ms", central_config="False"
)
server = grpc.aio.server()
pb2_grpc.add_TestServiceServicer_to_server(TestServiceAsync(), server)
server.add_insecure_port(f"[::]:{port}")
await server.start()
await server.wait_for_termination()


if __name__ == "__main__":
if len(sys.argv) > 1:
port = sys.argv[1]
else:
port = "50051"
logging.basicConfig()
serve(port)
if os.environ.get("GRPC_SERVER_ASYNC") == "1":
asyncio.run(serve_async(port))
else:
serve(port)
36 changes: 28 additions & 8 deletions tests/contrib/grpc/grpc_client_tests.py
Expand Up @@ -49,11 +49,14 @@
from .grpc_app.testgrpc_pb2_grpc import TestServiceStub


@pytest.fixture()
def grpc_server(validating_httpserver, request):
def setup_env(request, validating_httpserver):
config = getattr(request, "param", {})
env = {f"ELASTIC_APM_{k.upper()}": str(v) for k, v in config.items()}
env.setdefault("ELASTIC_APM_SERVER_URL", validating_httpserver.url)
return env


def setup_grpc_server(env):
free_port = get_free_port()
server_proc = subprocess.Popen(
[os.path.join(sys.prefix, "bin", "python"), "-m", "tests.contrib.grpc.grpc_app.server", str(free_port)],
Expand All @@ -62,15 +65,32 @@ def grpc_server(validating_httpserver, request):
env=env,
)
wait_for_open_port(free_port)
yield f"localhost:{free_port}"
server_proc.terminate()
return server_proc, free_port


@pytest.fixture()
def grpc_client_and_server_url(grpc_server):
test_channel = grpc.insecure_channel(grpc_server)
def env_fixture(validating_httpserver, request):
env = setup_env(request, validating_httpserver)
return env


if hasattr(grpc, "aio"):
grpc_server_fixture_params = ["async", "sync"]
else:
grpc_server_fixture_params = ["sync"]


@pytest.fixture(params=grpc_server_fixture_params)
def grpc_client_and_server_url(env_fixture, request):
env = {k: v for k, v in env_fixture.items()}
if request.param == "async":
env["GRPC_SERVER_ASYNC"] = "1"
server_proc, free_port = setup_grpc_server(env)
server_addr = f"localhost:{free_port}"
test_channel = grpc.insecure_channel(server_addr)
test_client = TestServiceStub(test_channel)
yield test_client, grpc_server
yield test_client, server_addr
server_proc.terminate()


def test_grpc_client_server_instrumentation(instrument, sending_elasticapm_client, grpc_client_and_server_url):
Expand Down Expand Up @@ -200,7 +220,7 @@ def test_grpc_client_unsampled_span(instrument, sending_elasticapm_client, grpc_


@pytest.mark.parametrize(
"grpc_server",
"env_fixture",
[
{
"recording": "False",
Expand Down