Skip to content

Commit

Permalink
Fix gRPC client interceptor for channels reused across traces
Browse files Browse the repository at this point in the history
In a typical web server gRPC clients/channels are reused across multiple
requests and hence across multiple traces. Previously the
`OpenCensusClientInterceptor` was instantiated for each channel with the
current tracer from the execution context. This would then lead to all
rpcs going through that channel to have the same tracer, essentially
grouping all rpcs under whatever happened to be the current trace when the
channel was created.

Instead instantiate `OpenCensusClientInterceptor` without a tracer by
default. The current tracer will be retrieved from the execution context at
the start of every rpc span.

In addition `OpenCensusClientInterceptor` was manipulating thread-local state
via the execution context. This seems unnecessary and misguided. The current
span state is already managed by the spans/context tracers. Setting the
tracer explicitly risks further subtle bugs.

Also removes unused method `OpenCensusClientInterceptor._end_span_between_context`.

Fixes #182
  • Loading branch information
Nik Haldimann committed Mar 5, 2019
1 parent ab363df commit b741f93
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 56 deletions.
18 changes: 3 additions & 15 deletions opencensus/trace/ext/google_cloud_clientlibs/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,7 @@ def call(*args, **kwargs):

try:
host = kwargs.get('host')
if tracer is None:
_tracer = execution_context.get_opencensus_tracer()
else: # pragma: NO COVER
_tracer = tracer
tracer_interceptor = OpenCensusClientInterceptor(_tracer, host)
tracer_interceptor = OpenCensusClientInterceptor(tracer, host)
intercepted_channel = grpc.intercept_channel(
channel, tracer_interceptor)
return intercepted_channel # pragma: NO COVER
Expand All @@ -112,11 +108,7 @@ def call(*args, **kwargs):

try:
target = kwargs.get('target')
if tracer is None:
_tracer = execution_context.get_opencensus_tracer()
else: # pragma: NO COVER
_tracer = tracer
tracer_interceptor = OpenCensusClientInterceptor(_tracer, target)
tracer_interceptor = OpenCensusClientInterceptor(tracer, target)
intercepted_channel = grpc.intercept_channel(
channel, tracer_interceptor)
return intercepted_channel # pragma: NO COVER
Expand All @@ -135,11 +127,7 @@ def call(*args, **kwargs):

try:
target = kwargs.get('target')
if tracer is None:
_tracer = execution_context.get_opencensus_tracer()
else: # pragma: NO COVER
_tracer = tracer
tracer_interceptor = OpenCensusClientInterceptor(_tracer, target)
tracer_interceptor = OpenCensusClientInterceptor(tracer, target)
intercepted_channel = grpc.intercept_channel(
channel, tracer_interceptor)
return intercepted_channel # pragma: NO COVER
Expand Down
8 changes: 0 additions & 8 deletions opencensus/trace/ext/grpc/client_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,8 @@ def _start_client_span(self, client_call_details):
attribute_key=attributes_helper.GRPC_ATTRIBUTES.get(GRPC_METHOD),
attribute_value=str(client_call_details.method))

execution_context.set_opencensus_tracer(self.tracer)
execution_context.set_current_span(span)

return span

def _end_span_between_context(self, current_span):
execution_context.set_current_span(current_span)
self.tracer.end_span()

def _intercept_call(
self, client_call_details, request_iterator, grpc_type
):
Expand Down Expand Up @@ -139,7 +132,6 @@ def callback(future_response):
span=current_span,
message_event_type=time_event.Type.RECEIVED,
)
execution_context.set_current_span(current_span)
self._trace_future_exception(future_response)
self.tracer.end_span()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,62 +74,46 @@ def test_trace_http(self):
self.assertTrue(mock_trace_requests.called)

def test_wrap_make_secure_channel(self):
mock_tracer = mock.Mock()
mock_interceptor = mock.Mock()
mock_func = mock.Mock()

patch_tracer = mock.patch(
'opencensus.trace.ext.google_cloud_clientlibs.trace'
'.execution_context.'
'get_opencensus_tracer',
return_value=mock_tracer)
patch_interceptor = mock.patch(
'opencensus.trace.ext.google_cloud_clientlibs.trace'
'.OpenCensusClientInterceptor', mock_interceptor)

wrapped = trace.wrap_make_secure_channel(mock_func)

with patch_tracer, patch_interceptor:
with patch_interceptor:
wrapped()

self.assertTrue(mock_interceptor.called)

def test_wrap_insecure_channel(self):
mock_tracer = mock.Mock()
mock_interceptor = mock.Mock()
mock_func = mock.Mock()

patch_tracer = mock.patch(
'opencensus.trace.ext.google_cloud_clientlibs.trace'
'.execution_context.get_opencensus_tracer',
return_value=mock_tracer)
patch_interceptor = mock.patch(
'opencensus.trace.ext.google_cloud_clientlibs.trace'
'.OpenCensusClientInterceptor', mock_interceptor)

wrapped = trace.wrap_insecure_channel(mock_func)

with patch_tracer, patch_interceptor:
with patch_interceptor:
wrapped()

self.assertTrue(mock_interceptor.called)

def test_wrap_create_channel(self):
mock_tracer = mock.Mock()
mock_interceptor = mock.Mock()
mock_func = mock.Mock()

patch_tracer = mock.patch(
'opencensus.trace.ext.google_cloud_clientlibs.trace'
'.execution_context.get_opencensus_tracer',
return_value=mock_tracer)
patch_interceptor = mock.patch(
'opencensus.trace.ext.google_cloud_clientlibs.trace'
'.OpenCensusClientInterceptor', mock_interceptor)

wrapped = trace.wrap_create_channel(mock_func)

with patch_tracer, patch_interceptor:
with patch_interceptor:
wrapped()

self.assertTrue(mock_interceptor.called)
14 changes: 0 additions & 14 deletions tests/unit/trace/ext/grpc/test_client_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,20 +61,6 @@ def test__start_client_span(self):
self.assertTrue(tracer.start_span.called)
self.assertTrue(tracer.add_attribute_to_current_span.called)

def test__end_span_between_context(self):
from opencensus.trace import execution_context

current_span = mock.Mock()
tracer = mock.Mock()
interceptor = client_interceptor.OpenCensusClientInterceptor(
tracer=tracer, host_port='test')
interceptor._end_span_between_context(current_span)

span_in_context = execution_context.get_current_span()

self.assertEqual(span_in_context, current_span)
self.assertTrue(tracer.end_span.called)

def test__intercept_call_metadata_none(self):
tracer = mock.Mock()
tracer.span_context = mock.Mock()
Expand Down

0 comments on commit b741f93

Please sign in to comment.