Skip to content

Commit

Permalink
fix asynchonous unary call traces (#536)
Browse files Browse the repository at this point in the history
  • Loading branch information
sengjea committed Jul 12, 2021
1 parent 753e228 commit 2ee2cf3
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 82 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
([#545](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/545))
- `openelemetry-sdk-extension-aws` Take a dependency on `opentelemetry-sdk`
([#558](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/558))

### Changed
- `opentelemetry-instrumentation-tornado` properly instrument work done in tornado on_finish method.
([#499](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/499))
Expand All @@ -33,6 +34,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Updating dependency for opentelemetry api/sdk packages to support major version instead
of pinning to specific versions.
([#567](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/567))
- `opentelemetry-instrumentation-grpc` Fixed asynchonous unary call traces
([#536](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/536))

### Added
- `opentelemetry-instrumentation-httpx` Add `httpx` instrumentation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,27 +33,6 @@
from opentelemetry.trace.status import Status, StatusCode


class _GuardedSpan:
def __init__(self, span):
self.span = span
self.generated_span = None
self._engaged = True

def __enter__(self):
self.generated_span = self.span.__enter__()
return self

def __exit__(self, *args, **kwargs):
if self._engaged:
self.generated_span = None
return self.span.__exit__(*args, **kwargs)
return False

def release(self):
self._engaged = False
return self.span


class _CarrierSetter(Setter):
"""We use a custom setter in order to be able to lower case
keys as is required by grpc.
Expand All @@ -68,7 +47,7 @@ def set(self, carrier: MutableMapping[str, str], key: str, value: str):

def _make_future_done_callback(span, rpc_info):
def callback(response_future):
with span:
with trace.use_span(span, end_on_exit=True):
code = response_future.code()
if code != grpc.StatusCode.OK:
rpc_info.error = code
Expand All @@ -85,7 +64,7 @@ class OpenTelemetryClientInterceptor(
def __init__(self, tracer):
self._tracer = tracer

def _start_span(self, method):
def _start_span(self, method, **kwargs):
service, meth = method.lstrip("/").split("/", 1)
attributes = {
SpanAttributes.RPC_SYSTEM: "grpc",
Expand All @@ -95,16 +74,19 @@ def _start_span(self, method):
}

return self._tracer.start_as_current_span(
name=method, kind=trace.SpanKind.CLIENT, attributes=attributes
name=method,
kind=trace.SpanKind.CLIENT,
attributes=attributes,
**kwargs,
)

# pylint:disable=no-self-use
def _trace_result(self, guarded_span, rpc_info, result):
# If the RPC is called asynchronously, release the guard and add a
# callback so that the span can be finished once the future is done.
def _trace_result(self, span, rpc_info, result):
# If the RPC is called asynchronously, add a callback to end the span
# when the future is done, else end the span immediately
if isinstance(result, grpc.Future):
result.add_done_callback(
_make_future_done_callback(guarded_span.release(), rpc_info)
_make_future_done_callback(span, rpc_info)
)
return result
response = result
Expand All @@ -115,41 +97,54 @@ def _trace_result(self, guarded_span, rpc_info, result):
if isinstance(result, tuple):
response = result[0]
rpc_info.response = response

span.end()
return result

def _start_guarded_span(self, *args, **kwargs):
return _GuardedSpan(self._start_span(*args, **kwargs))

def intercept_unary(self, request, metadata, client_info, invoker):
def _intercept(self, request, metadata, client_info, invoker):
if not metadata:
mutable_metadata = OrderedDict()
else:
mutable_metadata = OrderedDict(metadata)

with self._start_guarded_span(client_info.full_method) as guarded_span:
inject(mutable_metadata, setter=_carrier_setter)
metadata = tuple(mutable_metadata.items())

rpc_info = RpcInfo(
full_method=client_info.full_method,
metadata=metadata,
timeout=client_info.timeout,
request=request,
)

with self._start_span(
client_info.full_method,
end_on_exit=False,
record_exception=False,
set_status_on_exception=False,
) as span:
result = None
try:
result = invoker(request, metadata)
except grpc.RpcError as err:
guarded_span.generated_span.set_status(
Status(StatusCode.ERROR)
inject(mutable_metadata, setter=_carrier_setter)
metadata = tuple(mutable_metadata.items())

rpc_info = RpcInfo(
full_method=client_info.full_method,
metadata=metadata,
timeout=client_info.timeout,
request=request,
)
guarded_span.generated_span.set_attribute(
SpanAttributes.RPC_GRPC_STATUS_CODE, err.code().value[0]

result = invoker(request, metadata)
except Exception as exc:
if isinstance(exc, grpc.RpcError):
span.set_attribute(
SpanAttributes.RPC_GRPC_STATUS_CODE,
exc.code().value[0],
)
span.set_status(
Status(
status_code=StatusCode.ERROR,
description="{}: {}".format(type(exc).__name__, exc),
)
)
raise err
span.record_exception(exc)
raise exc
finally:
if not result:
span.end()
return self._trace_result(span, rpc_info, result)

return self._trace_result(guarded_span, rpc_info, result)
def intercept_unary(self, request, metadata, client_info, invoker):
return self._intercept(request, metadata, client_info, invoker)

# For RPCs that stream responses, the result can be a generator. To record
# the span across the generated responses and detect any errors, we wrap
Expand Down Expand Up @@ -194,32 +189,6 @@ def intercept_stream(
request_or_iterator, metadata, client_info, invoker
)

if not metadata:
mutable_metadata = OrderedDict()
else:
mutable_metadata = OrderedDict(metadata)

with self._start_guarded_span(client_info.full_method) as guarded_span:
inject(mutable_metadata, setter=_carrier_setter)
metadata = tuple(mutable_metadata.items())
rpc_info = RpcInfo(
full_method=client_info.full_method,
metadata=metadata,
timeout=client_info.timeout,
request=request_or_iterator,
)

rpc_info.request = request_or_iterator

try:
result = invoker(request_or_iterator, metadata)
except grpc.RpcError as err:
guarded_span.generated_span.set_status(
Status(StatusCode.ERROR)
)
guarded_span.generated_span.set_attribute(
SpanAttributes.RPC_GRPC_STATUS_CODE, err.code().value[0],
)
raise err

return self._trace_result(guarded_span, rpc_info, result)
return self._intercept(
request_or_iterator, metadata, client_info, invoker
)
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ def simple_method(stub, error=False):
stub.SimpleMethod(request)


def simple_method_future(stub, error=False):
request = Request(
client_id=CLIENT_ID, request_data="error" if error else "data"
)
return stub.SimpleMethod.future(request)


def client_streaming_method(stub, error=False):
# create a generator
def request_messages():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
client_streaming_method,
server_streaming_method,
simple_method,
simple_method_future,
)
from ._server import create_test_server
from .protobuf.test_server_pb2 import Request
Expand Down Expand Up @@ -100,6 +101,20 @@ def tearDown(self):
self.server.stop(None)
self.channel.close()

def test_unary_unary_future(self):
simple_method_future(self._stub).result()
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)
span = spans[0]

self.assertEqual(span.name, "/GRPCTestServer/SimpleMethod")
self.assertIs(span.kind, trace.SpanKind.CLIENT)

# Check version and name in span's instrumentation info
self.check_span_instrumentation_info(
span, opentelemetry.instrumentation.grpc
)

def test_unary_unary(self):
simple_method(self._stub)
spans = self.memory_exporter.get_finished_spans()
Expand Down

0 comments on commit 2ee2cf3

Please sign in to comment.