From 4c813c47e4efa9432ea9176512f13f724d215fd6 Mon Sep 17 00:00:00 2001 From: Jakub Wach Date: Fri, 7 Jan 2022 22:12:58 +0100 Subject: [PATCH] ASGI: Conditionally create SERVER spans (#843) --- CHANGELOG.md | 4 ++ .../instrumentation/asgi/__init__.py | 45 ++++++++++++------ .../tests/test_fastapi_instrumentation.py | 46 +++++++++++++++++++ 3 files changed, 80 insertions(+), 15 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ee27a96ba8..dcf9a79be7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `opentelemetry-instrumentation-flask` Flask: Conditionally create SERVER spans ([#828](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/828)) +- `opentelemetry-instrumentation-asgi` ASGI: Conditionally create SERVER spans + ([#843](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/843)) + + ## [1.8.0-0.27b0](https://github.com/open-telemetry/opentelemetry-python/releases/tag/v1.8.0-0.27b0) - 2021-12-17 ### Added diff --git a/instrumentation/opentelemetry-instrumentation-asgi/src/opentelemetry/instrumentation/asgi/__init__.py b/instrumentation/opentelemetry-instrumentation-asgi/src/opentelemetry/instrumentation/asgi/__init__.py index a953165473..8b034b95f0 100644 --- a/instrumentation/opentelemetry-instrumentation-asgi/src/opentelemetry/instrumentation/asgi/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-asgi/src/opentelemetry/instrumentation/asgi/__init__.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# pylint: disable=too-many-locals """ The opentelemetry-instrumentation-asgi package provides an ASGI middleware that can be used @@ -110,7 +111,12 @@ def client_response_hook(span: Span, message: dict): from opentelemetry.propagate import extract from opentelemetry.propagators.textmap import Getter, Setter from opentelemetry.semconv.trace import SpanAttributes -from opentelemetry.trace import Span, set_span_in_context +from opentelemetry.trace import ( + INVALID_SPAN, + Span, + SpanKind, + set_span_in_context, +) from opentelemetry.trace.status import Status, StatusCode from opentelemetry.util.http import remove_url_credentials @@ -321,39 +327,48 @@ async def __call__(self, scope, receive, send): if self.excluded_urls and self.excluded_urls.url_disabled(url): return await self.app(scope, receive, send) - token = context.attach(extract(scope, getter=asgi_getter)) - server_span_name, additional_attributes = self.default_span_details( - scope - ) + token = ctx = span_kind = None + + if trace.get_current_span() is INVALID_SPAN: + ctx = extract(scope, getter=asgi_getter) + token = context.attach(ctx) + span_kind = SpanKind.SERVER + else: + ctx = context.get_current() + span_kind = SpanKind.INTERNAL + + span_name, additional_attributes = self.default_span_details(scope) try: with self.tracer.start_as_current_span( - server_span_name, - kind=trace.SpanKind.SERVER, - ) as server_span: - if server_span.is_recording(): + span_name, + context=ctx, + kind=span_kind, + ) as current_span: + if current_span.is_recording(): attributes = collect_request_attributes(scope) attributes.update(additional_attributes) for key, value in attributes.items(): - server_span.set_attribute(key, value) + current_span.set_attribute(key, value) if callable(self.server_request_hook): - self.server_request_hook(server_span, scope) + self.server_request_hook(current_span, scope) otel_receive = self._get_otel_receive( - server_span_name, scope, receive + span_name, scope, receive ) otel_send = self._get_otel_send( - server_span, - server_span_name, + current_span, + span_name, scope, send, ) await self.app(scope, otel_receive, otel_send) finally: - context.detach(token) + if token: + context.detach(token) def _get_otel_receive(self, server_span_name, scope, receive): @wraps(receive) diff --git a/instrumentation/opentelemetry-instrumentation-fastapi/tests/test_fastapi_instrumentation.py b/instrumentation/opentelemetry-instrumentation-fastapi/tests/test_fastapi_instrumentation.py index 97bd1e9c9e..ae963e4f87 100644 --- a/instrumentation/opentelemetry-instrumentation-fastapi/tests/test_fastapi_instrumentation.py +++ b/instrumentation/opentelemetry-instrumentation-fastapi/tests/test_fastapi_instrumentation.py @@ -19,6 +19,7 @@ from fastapi.testclient import TestClient import opentelemetry.instrumentation.fastapi as otel_fastapi +from opentelemetry import trace from opentelemetry.instrumentation.asgi import OpenTelemetryMiddleware from opentelemetry.sdk.resources import Resource from opentelemetry.semconv.trace import SpanAttributes @@ -329,3 +330,48 @@ def test_instrumentation(self): should_be_original = fastapi.FastAPI self.assertIs(original, should_be_original) + + +class TestWrappedApplication(TestBase): + def setUp(self): + super().setUp() + + self.app = fastapi.FastAPI() + + @self.app.get("/foobar") + async def _(): + return {"message": "hello world"} + + otel_fastapi.FastAPIInstrumentor().instrument_app(self.app) + self.client = TestClient(self.app) + self.tracer = self.tracer_provider.get_tracer(__name__) + + def tearDown(self) -> None: + super().tearDown() + with self.disable_logging(): + otel_fastapi.FastAPIInstrumentor().uninstrument_app(self.app) + + def test_mark_span_internal_in_presence_of_span_from_other_framework(self): + with self.tracer.start_as_current_span( + "test", kind=trace.SpanKind.SERVER + ) as parent_span: + resp = self.client.get("/foobar") + self.assertEqual(200, resp.status_code) + + span_list = self.memory_exporter.get_finished_spans() + for span in span_list: + print(str(span.__class__) + ": " + str(span.__dict__)) + + # there should be 4 spans - single SERVER "test" and three INTERNAL "FastAPI" + self.assertEqual(trace.SpanKind.INTERNAL, span_list[0].kind) + self.assertEqual(trace.SpanKind.INTERNAL, span_list[1].kind) + # main INTERNAL span - child of test + self.assertEqual(trace.SpanKind.INTERNAL, span_list[2].kind) + self.assertEqual( + parent_span.context.span_id, span_list[2].parent.span_id + ) + # SERVER "test" + self.assertEqual(trace.SpanKind.SERVER, span_list[3].kind) + self.assertEqual( + parent_span.context.span_id, span_list[3].context.span_id + )