diff --git a/sentry_sdk/integrations/starlette.py b/sentry_sdk/integrations/starlette.py index 3f78dc4c43..c417b834be 100644 --- a/sentry_sdk/integrations/starlette.py +++ b/sentry_sdk/integrations/starlette.py @@ -73,14 +73,20 @@ class StarletteIntegration(Integration): transaction_style = "" - def __init__(self, transaction_style="url", failed_request_status_codes=None): - # type: (str, Optional[list[HttpStatusCodeRange]]) -> None + def __init__( + self, + transaction_style="url", + failed_request_status_codes=None, + middleware_spans=True, + ): + # type: (str, Optional[list[HttpStatusCodeRange]], bool) -> None if transaction_style not in TRANSACTION_STYLE_VALUES: raise ValueError( "Invalid value for transaction_style: %s (must be in %s)" % (transaction_style, TRANSACTION_STYLE_VALUES) ) self.transaction_style = transaction_style + self.middleware_spans = middleware_spans self.failed_request_status_codes = failed_request_status_codes or [ range(500, 599) ] @@ -110,7 +116,7 @@ def _enable_span_for_middleware(middleware_class): async def _create_span_call(app, scope, receive, send, **kwargs): # type: (Any, Dict[str, Any], Callable[[], Awaitable[Dict[str, Any]]], Callable[[Dict[str, Any]], Awaitable[None]], Any) -> None integration = sentry_sdk.get_client().get_integration(StarletteIntegration) - if integration is None: + if integration is None or not integration.middleware_spans: return await old_call(app, scope, receive, send, **kwargs) middleware_name = app.__class__.__name__ diff --git a/tests/integrations/starlette/test_starlette.py b/tests/integrations/starlette/test_starlette.py index 411be72f6f..918ad1185e 100644 --- a/tests/integrations/starlette/test_starlette.py +++ b/tests/integrations/starlette/test_starlette.py @@ -637,20 +637,49 @@ def test_middleware_spans(sentry_init, capture_events): (_, transaction_event) = events - expected = [ + expected_middleware_spans = [ "ServerErrorMiddleware", "AuthenticationMiddleware", "ExceptionMiddleware", + "AuthenticationMiddleware", # 'op': 'middleware.starlette.send' + "ServerErrorMiddleware", # 'op': 'middleware.starlette.send' + "AuthenticationMiddleware", # 'op': 'middleware.starlette.send' + "ServerErrorMiddleware", # 'op': 'middleware.starlette.send' ] + assert len(transaction_event["spans"]) == len(expected_middleware_spans) + idx = 0 for span in transaction_event["spans"]: - if span["op"] == "middleware.starlette": - assert span["description"] == expected[idx] - assert span["tags"]["starlette.middleware_name"] == expected[idx] + if span["op"].startswith("middleware.starlette"): + assert ( + span["tags"]["starlette.middleware_name"] + == expected_middleware_spans[idx] + ) idx += 1 +def test_middleware_spans_disabled(sentry_init, capture_events): + sentry_init( + traces_sample_rate=1.0, + integrations=[StarletteIntegration(middleware_spans=False)], + ) + starlette_app = starlette_app_factory( + middleware=[Middleware(AuthenticationMiddleware, backend=BasicAuthBackend())] + ) + events = capture_events() + + client = TestClient(starlette_app, raise_server_exceptions=False) + try: + client.get("/message", auth=("Gabriela", "hello123")) + except Exception: + pass + + (_, transaction_event) = events + + assert len(transaction_event["spans"]) == 0 + + def test_middleware_callback_spans(sentry_init, capture_events): sentry_init( traces_sample_rate=1.0,