diff --git a/opencensus/trace/ext/flask/flask_middleware.py b/opencensus/trace/ext/flask/flask_middleware.py index 8c9ddf1cc..2a3db49c5 100644 --- a/opencensus/trace/ext/flask/flask_middleware.py +++ b/opencensus/trace/ext/flask/flask_middleware.py @@ -260,18 +260,21 @@ def _teardown_request(self, exception): if exception is not None: span = execution_context.get_current_span() - span.status = status.Status( - code=code_pb2.UNKNOWN, - message=str(exception) - ) - # try attaching the stack trace to the span, only populated if - # the app has 'PROPAGATE_EXCEPTIONS', 'DEBUG', or 'TESTING' - # enabled - exc_type, _, exc_traceback = sys.exc_info() - if exc_traceback is not None: - span.stack_trace = stack_trace.StackTrace.from_traceback( - exc_traceback + if span is not None: + span.status = status.Status( + code=code_pb2.UNKNOWN, + message=str(exception) ) + # try attaching the stack trace to the span, only populated + # if the app has 'PROPAGATE_EXCEPTIONS', 'DEBUG', or + # 'TESTING' enabled + exc_type, _, exc_traceback = sys.exc_info() + if exc_traceback is not None: + span.stack_trace = ( + stack_trace.StackTrace.from_traceback( + exc_traceback + ) + ) tracer.end_span() tracer.finish() diff --git a/tests/unit/trace/ext/flask/test_flask_middleware.py b/tests/unit/trace/ext/flask/test_flask_middleware.py index fcf2f38cb..3f1aad18c 100644 --- a/tests/unit/trace/ext/flask/test_flask_middleware.py +++ b/tests/unit/trace/ext/flask/test_flask_middleware.py @@ -17,24 +17,32 @@ import unittest +from google.rpc import code_pb2 import flask import mock -from google.rpc import code_pb2 from opencensus.trace import execution_context -from opencensus.trace import span_data from opencensus.trace import span as span_module +from opencensus.trace import span_data from opencensus.trace import stack_trace from opencensus.trace import status -from opencensus.trace.exporters import print_exporter, stackdriver_exporter, \ - zipkin_exporter, jaeger_exporter +from opencensus.trace.blank_span import BlankSpan +from opencensus.trace.exporters import jaeger_exporter +from opencensus.trace.exporters import print_exporter +from opencensus.trace.exporters import stackdriver_exporter +from opencensus.trace.exporters import zipkin_exporter from opencensus.trace.exporters.ocagent import trace_exporter from opencensus.trace.ext.flask import flask_middleware from opencensus.trace.propagation import google_cloud_format from opencensus.trace.samplers import always_off, always_on, ProbabilitySampler +from opencensus.trace.span_context import SpanContext +from opencensus.trace.trace_options import TraceOptions from opencensus.trace.tracers import base from opencensus.trace.tracers import noop_tracer -from opencensus.trace.blank_span import BlankSpan + + +class FlaskTestException(Exception): + pass class TestFlaskMiddleware(unittest.TestCase): @@ -53,7 +61,7 @@ def health_check(): @app.route('/error') def error(): - raise Exception('error') + raise FlaskTestException('error') return app @@ -458,7 +466,7 @@ def test_teardown_include_exception_and_traceback(self): app = self.create_app() app.config['TESTING'] = True flask_middleware.FlaskMiddleware(app=app, exporter=mock_exporter) - with self.assertRaises(Exception): + with self.assertRaises(FlaskTestException): app.test_client().get('/error') exported_spandata = mock_exporter.export.call_args[0][0][0] @@ -471,3 +479,24 @@ def test_teardown_include_exception_and_traceback(self): ) self.assertIsNotNone(exported_spandata.stack_trace.stack_trace_hash_id) self.assertNotEqual(exported_spandata.stack_trace.stack_frames, []) + + def test_teardown_include_exception_and_traceback_span_disabled(self): + sampler = always_off.AlwaysOffSampler() + app = self.create_app() + app.config['TESTING'] = True + middleware = flask_middleware.FlaskMiddleware(app=app, sampler=sampler) + + # TODO: send trace options in header (#465) + original_method = middleware.propagator.from_headers + + def nope(*args, **kwargs): + trace_options = TraceOptions() + trace_options.set_enabled(False) + return SpanContext(trace_options=trace_options) + + middleware.propagator.from_headers = nope + + with self.assertRaises(FlaskTestException): + app.test_client().get('/error') + + middleware.propagator.from_headers = original_method