diff --git a/opencensus/trace/ext/grpc/server_interceptor.py b/opencensus/trace/ext/grpc/server_interceptor.py index 86e633e14..6e07099ab 100644 --- a/opencensus/trace/ext/grpc/server_interceptor.py +++ b/opencensus/trace/ext/grpc/server_interceptor.py @@ -11,12 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import collections +import logging import grpc -import logging from opencensus.trace import attributes_helper from opencensus.trace import tracer as tracer_module +from opencensus.trace import execution_context from opencensus.trace.ext import grpc as oc_grpc from opencensus.trace.propagation import binary_format @@ -24,24 +26,72 @@ ATTRIBUTE_ERROR_NAME = 'ERROR_NAME' ATTRIBUTE_ERROR_MESSAGE = 'ERROR_MESSAGE' +RpcRequestInfo = collections.namedtuple( + 'RPCRequestInfo', ('request', 'context') +) +RpcResponseInfo = collections.namedtuple( + 'RPCCallbackInfo', ('request', 'context', 'response', 'exc') +) + + +class RpcMethodHandlerWrapper(object): + """Wraps a grpc RPCMethodHandler and records the variables about the + execution context and response + """ + + def __init__( + self, handler, pre_handler_callbacks=None, post_handler_callbacks=None + ): + """ + :param handler: instance of RpcMethodHandler + + :param pre_handler_callbacks: iterable of callbacks that accept an + instance of RpcRequestInfo that are called before the server handler + + :param post_handler_callbacks: iterable of callbacks that accept an + instance of RpcResponseInfo that are called after the server + handler finishes execution + """ + self.handler = handler + self._pre_handler_callbacks = pre_handler_callbacks or [] + self._post_handler_callbacks = post_handler_callbacks or [] + + def proxy(self, prop_name): + def _wrapper(request, context, *args, **kwargs): + for callback in self._pre_handler_callbacks: + callback(RpcRequestInfo(request, context)) + exc = None + response = None + try: + response = getattr( + self.handler, prop_name + )(request, context, *args, **kwargs) + except Exception as e: + logging.error(e) + exc = e + raise + finally: + for callback in self._post_handler_callbacks: + callback(RpcResponseInfo(request, context, response, exc)) + return response + + return _wrapper -class OpenCensusServerInterceptor(grpc.ServerInterceptor): + def __getattr__(self, item): + if item in ( + 'unary_unary', 'unary_stream', 'stream_unary', 'stream_stream' + ): + return self.proxy(item) + return getattr(self.handler, item) + +class OpenCensusServerInterceptor(grpc.ServerInterceptor): def __init__(self, sampler=None, exporter=None): self.sampler = sampler self.exporter = exporter - def _start_server_span(self, tracer): - span = tracer.start_span(name='grpc_server') - tracer.add_attribute_to_current_span( - attribute_key=attributes_helper.COMMON_ATTRIBUTES.get( - ATTRIBUTE_COMPONENT), - attribute_value='grpc') - - return span - - def intercept_handler(self, continuation, handler_call_details): - metadata = handler_call_details.invocation_metadata + def _start_server_span(self, rpc_request_info): + metadata = rpc_request_info.context.invocation_metadata() span_context = None if metadata is not None: @@ -55,21 +105,30 @@ def intercept_handler(self, continuation, handler_call_details): sampler=self.sampler, exporter=self.exporter) - with self._start_server_span(tracer): - response = None + span = tracer.start_span(name='grpc_server') + tracer.add_attribute_to_current_span( + attribute_key=attributes_helper.COMMON_ATTRIBUTES.get( + ATTRIBUTE_COMPONENT), + attribute_value='grpc') - try: - response = continuation(handler_call_details) - except Exception as e: # pragma: NO COVER - logging.error(e) - tracer.add_attribute_to_current_span( - attributes_helper.COMMON_ATTRIBUTES.get( - ATTRIBUTE_ERROR_MESSAGE), - str(e)) - tracer.end_span() - raise + execution_context.set_opencensus_tracer(tracer) + execution_context.set_current_span(span) + + def _end_server_span(self, rpc_response_info): + tracer = execution_context.get_opencensus_tracer() + if rpc_response_info.exc is not None: + tracer.add_attribute_to_current_span( + attributes_helper.COMMON_ATTRIBUTES.get( + ATTRIBUTE_ERROR_MESSAGE), + str(rpc_response_info.exc)) + tracer.end_span() - return response + def intercept_handler(self, continuation, handler_call_details): + return RpcMethodHandlerWrapper( + continuation(handler_call_details), + pre_handler_callbacks=[self._start_server_span], + post_handler_callbacks=[self._end_server_span] + ) def intercept_service(self, continuation, handler_call_details): return self.intercept_handler(continuation, handler_call_details) diff --git a/tests/unit/trace/ext/grpc/test_server_interceptor.py b/tests/unit/trace/ext/grpc/test_server_interceptor.py index dee3aaa48..d7c4d29e9 100644 --- a/tests/unit/trace/ext/grpc/test_server_interceptor.py +++ b/tests/unit/trace/ext/grpc/test_server_interceptor.py @@ -16,8 +16,8 @@ import mock -from opencensus.trace.ext.grpc import server_interceptor from opencensus.trace import execution_context +from opencensus.trace.ext.grpc import server_interceptor class TestOpenCensusServerInterceptor(unittest.TestCase): @@ -29,19 +29,28 @@ def test_constructor(self): self.assertEqual(interceptor.sampler, sampler) self.assertEqual(interceptor.exporter, exporter) + def test_rpc_handler_wrapper(self): + """Ensure that RPCHandlerWrapper proxies to the unerlying handler""" + mock_handler = mock.Mock() + mock_handler.response_streaming = False + wrapper = server_interceptor.RpcMethodHandlerWrapper(mock_handler) + self.assertEqual(wrapper.response_streaming, False) + def test_intercept_handler_no_metadata(self): current_span = mock.Mock() mock_tracer = MockTracer(None, None, None) patch = mock.patch( 'opencensus.trace.ext.grpc.server_interceptor.tracer_module.Tracer', MockTracer) - mock_details = mock.Mock() - mock_details.invocation_metadata = None + mock_context = mock.Mock() + mock_context.invocation_metadata = mock.Mock(return_value=None) interceptor = server_interceptor.OpenCensusServerInterceptor( None, None) with patch: - interceptor.intercept_handler(mock.Mock(), mock_details) + interceptor.intercept_handler( + mock.Mock(), mock.Mock() + ).unary_unary(mock.Mock(), mock_context) expected_attributes = { '/component': 'grpc', @@ -57,13 +66,17 @@ def test_intercept_handler(self): patch = mock.patch( 'opencensus.trace.ext.grpc.server_interceptor.tracer_module.Tracer', MockTracer) - mock_details = mock.Mock() - mock_details.invocation_metadata = (('test_key', b'test_value'),) + mock_context = mock.Mock() + mock_context.invocation_metadata = mock.Mock( + return_value=(('test_key', b'test_value'),) + ) interceptor = server_interceptor.OpenCensusServerInterceptor( None, None) with patch: - interceptor.intercept_handler(mock.Mock(), mock_details) + interceptor.intercept_handler( + mock.Mock(), mock.Mock() + ).unary_unary(mock.Mock(), mock_context) expected_attributes = { '/component': 'grpc', @@ -81,6 +94,36 @@ def test_intercept_service(self): interceptor.intercept_service(None, None) self.assertTrue(mock_handler.called) + def test_intercept_handler_exception(self): + current_span = mock.Mock() + mock_tracer = MockTracer(None, None, None) + patch = mock.patch( + 'opencensus.trace.ext.grpc.server_interceptor.tracer_module.Tracer', + MockTracer) + interceptor = server_interceptor.OpenCensusServerInterceptor( + None, None) + mock_context = mock.Mock() + mock_context.invocation_metadata = mock.Mock(return_value=None) + mock_continuation = mock.Mock() + mock_continuation.unary_unary = mock.Mock(side_effect=Exception('Test')) + with patch: + # patch the wrapper's handler to return an exception + rpc_wrapper = interceptor.intercept_handler( + mock.Mock(), mock.Mock()) + rpc_wrapper.handler.unary_unary = mock.Mock( + side_effect=Exception('Test')) + with self.assertRaises(Exception): + rpc_wrapper.unary_unary(mock.Mock(), mock_context) + + expected_attributes = { + '/component': 'grpc', + '/error/message': 'Test' + } + + self.assertEqual( + execution_context.get_opencensus_tracer().current_span.attributes, + expected_attributes) + class MockTracer(object): def __init__(self, *args, **kwargs):