Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 38 additions & 13 deletions sentry_sdk/integrations/starlette.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@
)
from sentry_sdk.integrations.asgi import SentryAsgiMiddleware
from sentry_sdk.scope import should_send_default_pii
from sentry_sdk.traces import NoOpStreamedSpan, StreamedSpan
from sentry_sdk.tracing import (
SOURCE_FOR_STYLE,
TransactionSource,
)
from sentry_sdk.tracing_utils import has_span_streaming_enabled
from sentry_sdk.utils import (
AnnotatedValue,
capture_internal_exceptions,
Expand Down Expand Up @@ -147,7 +149,8 @@ async def _create_span_call(
send: "Callable[[Dict[str, Any]], Awaitable[None]]",
**kwargs: "Any",
) -> None:
integration = sentry_sdk.get_client().get_integration(StarletteIntegration)
client = sentry_sdk.get_client()
integration = client.get_integration(StarletteIntegration)
if integration is None:
return await old_call(app, scope, receive, send, **kwargs)

Expand All @@ -164,22 +167,38 @@ async def _create_span_call(
return await old_call(app, scope, receive, send, **kwargs)

middleware_name = app.__class__.__name__
is_span_streaming_enabled = has_span_streaming_enabled(client.options)

def _start_middleware_span(op: str, name: str) -> "Any":
if is_span_streaming_enabled:
return sentry_sdk.traces.start_span(
name=name,
attributes={
"sentry.op": op,
"sentry.origin": StarletteIntegration.origin,
"starlette.middleware_name": middleware_name,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

starlette.middleware_name is not in Sentry semantic conventions -- anything we set as an attribute on a streamed span needs to be there. So we either need to add it or not set it at all.

In this case, if I read the code right, it seems to be the same as the span name? In that case I'd omit the attribute. If we want to keep it, I'd propose a new attribute name in Sentry conventions that's not tied to Starlette specifically, so that we can use it across different frameworks and languages.

},
)
return sentry_sdk.start_span(
op=op,
name=name,
origin=StarletteIntegration.origin,
)

with sentry_sdk.start_span(
op=OP.MIDDLEWARE_STARLETTE,
name=middleware_name,
origin=StarletteIntegration.origin,
with _start_middleware_span(
op=OP.MIDDLEWARE_STARLETTE, name=middleware_name
) as middleware_span:
middleware_span.set_tag("starlette.middleware_name", middleware_name)
if not is_span_streaming_enabled:
middleware_span.set_tag("starlette.middleware_name", middleware_name)

# Creating spans for the "receive" callback
async def _sentry_receive(*args: "Any", **kwargs: "Any") -> "Any":
with sentry_sdk.start_span(
with _start_middleware_span(
op=OP.MIDDLEWARE_STARLETTE_RECEIVE,
name=getattr(receive, "__qualname__", str(receive)),
origin=StarletteIntegration.origin,
) as span:
span.set_tag("starlette.middleware_name", middleware_name)
if not is_span_streaming_enabled:
span.set_tag("starlette.middleware_name", middleware_name)
return await receive(*args, **kwargs)

receive_name = getattr(receive, "__name__", str(receive))
Expand All @@ -188,12 +207,12 @@ async def _sentry_receive(*args: "Any", **kwargs: "Any") -> "Any":

# Creating spans for the "send" callback
async def _sentry_send(*args: "Any", **kwargs: "Any") -> "Any":
with sentry_sdk.start_span(
with _start_middleware_span(
op=OP.MIDDLEWARE_STARLETTE_SEND,
name=getattr(send, "__qualname__", str(send)),
origin=StarletteIntegration.origin,
) as span:
span.set_tag("starlette.middleware_name", middleware_name)
if not is_span_streaming_enabled:
span.set_tag("starlette.middleware_name", middleware_name)
return await send(*args, **kwargs)

send_name = getattr(send, "__name__", str(send))
Expand Down Expand Up @@ -496,7 +515,13 @@ def _sentry_sync_func(*args: "Any", **kwargs: "Any") -> "Any":
return old_func(*args, **kwargs)

current_scope = sentry_sdk.get_current_scope()
if current_scope.transaction is not None:
current_span = current_scope.span

if isinstance(current_span, StreamedSpan) and not isinstance(
current_span, NoOpStreamedSpan
):
current_span._segment._update_active_thread()
elif current_scope.transaction is not None:
current_scope.transaction.update_active_thread()

sentry_scope = sentry_sdk.get_isolation_scope()
Expand Down
192 changes: 159 additions & 33 deletions tests/integrations/starlette/test_starlette.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import pytest

import sentry_sdk
from sentry_sdk import capture_message, get_baggage, get_traceparent
from sentry_sdk.integrations.asgi import SentryAsgiMiddleware
from sentry_sdk.integrations.starlette import (
Expand Down Expand Up @@ -648,24 +649,30 @@ def test_user_information_transaction_no_pii(sentry_init, capture_events):
assert "user" not in transaction_event


def test_middleware_spans(sentry_init, capture_events):
@pytest.mark.parametrize("span_streaming", [True, False])
def test_middleware_spans(sentry_init, capture_events, capture_items, span_streaming):
sentry_init(
traces_sample_rate=1.0,
integrations=[StarletteIntegration(middleware_spans=True)],
_experiments={
"trace_lifecycle": "stream" if span_streaming else "static",
},
)
starlette_app = starlette_app_factory(
middleware=[Middleware(AuthenticationMiddleware, backend=BasicAuthBackend())]
)
events = capture_events()

if span_streaming:
items = capture_items("span")
else:
events = capture_events()

client = TestClient(starlette_app, raise_server_exceptions=False)
try:
client.get("/message", auth=("Gabriela", "hello123"))
except Exception:
pass

(_, transaction_event) = events

expected_middleware_spans = [
"ServerErrorMiddleware",
"AuthenticationMiddleware",
Expand All @@ -676,55 +683,108 @@ def test_middleware_spans(sentry_init, capture_events):
"ServerErrorMiddleware", # 'op': 'middleware.starlette.send'
]

assert len(transaction_event["spans"]) == len(expected_middleware_spans)
if span_streaming:
sentry_sdk.flush()

middleware_spans = sorted(
[
item.payload
for item in items
if item.payload.get("attributes", {})
.get("sentry.op", "")
.startswith("middleware.starlette")
],
key=lambda s: s["start_timestamp"],
)

idx = 0
for span in transaction_event["spans"]:
if span["op"].startswith("middleware.starlette"):
assert len(middleware_spans) == len(expected_middleware_spans)

for idx, span in enumerate(middleware_spans):
assert (
span["tags"]["starlette.middleware_name"]
span["attributes"]["starlette.middleware_name"]
== expected_middleware_spans[idx]
)
idx += 1
else:
(_, transaction_event) = events

assert len(transaction_event["spans"]) == len(expected_middleware_spans)

idx = 0
for span in transaction_event["spans"]:
if span["op"].startswith("middleware.starlette"):
assert (
span["tags"]["starlette.middleware_name"]
== expected_middleware_spans[idx]
)
idx += 1

def test_middleware_spans_disabled(sentry_init, capture_events):

@pytest.mark.parametrize("span_streaming", [True, False])
def test_middleware_spans_disabled(
sentry_init, capture_events, capture_items, span_streaming
):
sentry_init(
traces_sample_rate=1.0,
integrations=[StarletteIntegration(middleware_spans=False)],
_experiments={
"trace_lifecycle": "stream" if span_streaming else "static",
},
)
starlette_app = starlette_app_factory(
middleware=[Middleware(AuthenticationMiddleware, backend=BasicAuthBackend())]
)
events = capture_events()

if span_streaming:
items = capture_items("span")
else:
events = capture_events()

client = TestClient(starlette_app, raise_server_exceptions=False)
try:
client.get("/message", auth=("Gabriela", "hello123"))
except Exception:
pass

(_, transaction_event) = events

assert len(transaction_event["spans"]) == 0
if span_streaming:
sentry_sdk.flush()

middleware_spans = [
item.payload
for item in items
if item.payload.get("attributes", {})
.get("sentry.op", "")
.startswith("middleware.starlette")
]
assert len(middleware_spans) == 0
else:
(_, transaction_event) = events
assert len(transaction_event["spans"]) == 0


def test_middleware_callback_spans(sentry_init, capture_events):
@pytest.mark.parametrize("span_streaming", [True, False])
def test_middleware_callback_spans(
sentry_init, capture_events, capture_items, span_streaming
):
sentry_init(
traces_sample_rate=1.0,
integrations=[StarletteIntegration()],
integrations=[StarletteIntegration(middleware_spans=True)],
_experiments={
"trace_lifecycle": "stream" if span_streaming else "static",
},
)
starlette_app = starlette_app_factory(middleware=[Middleware(SampleMiddleware)])
events = capture_events()

if span_streaming:
items = capture_items("span")
else:
events = capture_events()

client = TestClient(starlette_app, raise_server_exceptions=False)
try:
client.get("/message", auth=("Gabriela", "hello123"))
except Exception:
pass

(_, transaction_event) = events

expected = [
{
"op": "middleware.starlette",
Expand Down Expand Up @@ -773,12 +833,37 @@ def test_middleware_callback_spans(sentry_init, capture_events):
},
]

idx = 0
for span in transaction_event["spans"]:
assert span["op"] == expected[idx]["op"]
assert span["description"] == expected[idx]["description"]
assert span["tags"] == expected[idx]["tags"]
idx += 1
if span_streaming:
sentry_sdk.flush()

middleware_spans = sorted(
[
item.payload
for item in items
if item.payload.get("attributes", {})
.get("sentry.op", "")
.startswith("middleware.starlette")
],
key=lambda s: s["start_timestamp"],
)

assert len(middleware_spans) == len(expected)
for span, exp in zip(middleware_spans, expected):
assert span["attributes"]["sentry.op"] == exp["op"]
assert span["name"] == exp["description"]
assert (
span["attributes"]["starlette.middleware_name"]
== exp["tags"]["starlette.middleware_name"]
)
else:
(_, transaction_event) = events

idx = 0
for span in transaction_event["spans"]:
assert span["op"] == expected[idx]["op"]
assert span["description"] == expected[idx]["description"]
assert span["tags"] == expected[idx]["tags"]
idx += 1


def test_middleware_receive_send(sentry_init, capture_events):
Expand Down Expand Up @@ -946,6 +1031,31 @@ def test_active_thread_id(sentry_init, capture_envelopes, teardown_profiling, en
assert str(data["active"]) == trace_context["data"]["thread.id"]


@pytest.mark.parametrize("endpoint", ["/sync/thread_ids", "/async/thread_ids"])
def test_active_thread_id_span_streaming(sentry_init, capture_items, endpoint):
sentry_init(
auto_enabling_integrations=False, # avoid legacy spans from auto-enabled integrations leaking into streaming mode
integrations=[StarletteIntegration()],
traces_sample_rate=1.0,
_experiments={"trace_lifecycle": "stream"},
)
app = starlette_app_factory()

items = capture_items("span")

client = TestClient(app)
response = client.get(endpoint)
assert response.status_code == 200

data = json.loads(response.content)

sentry_sdk.flush()

segments = [item.payload for item in items if item.payload.get("is_segment")]
assert len(segments) == 1
assert str(data["active"]) == segments[0]["attributes"]["thread.id"]


def test_original_request_not_scrubbed(sentry_init, capture_events):
sentry_init(integrations=[StarletteIntegration()])

Expand Down Expand Up @@ -1167,27 +1277,43 @@ def test_transaction_name_in_middleware(
)


def test_span_origin(sentry_init, capture_events):
@pytest.mark.parametrize("span_streaming", [True, False])
def test_span_origin(sentry_init, capture_events, capture_items, span_streaming):
sentry_init(
integrations=[StarletteIntegration()],
auto_enabling_integrations=False, # avoid httpx auto-instrumentation leaking spans
integrations=[StarletteIntegration(middleware_spans=True)],
traces_sample_rate=1.0,
_experiments={
"trace_lifecycle": "stream" if span_streaming else "static",
},
)
starlette_app = starlette_app_factory(
middleware=[Middleware(AuthenticationMiddleware, backend=BasicAuthBackend())]
)
events = capture_events()

if span_streaming:
items = capture_items("span")
else:
events = capture_events()

client = TestClient(starlette_app, raise_server_exceptions=False)
try:
client.get("/message", auth=("Gabriela", "hello123"))
except Exception:
pass

(_, event) = events
if span_streaming:
sentry_sdk.flush()

assert len(items) > 0
for item in items:
assert item.payload["attributes"]["sentry.origin"] == "auto.http.starlette"
else:
(_, event) = events

assert event["contexts"]["trace"]["origin"] == "auto.http.starlette"
for span in event["spans"]:
assert span["origin"] == "auto.http.starlette"
assert event["contexts"]["trace"]["origin"] == "auto.http.starlette"
for span in event["spans"]:
assert span["origin"] == "auto.http.starlette"


class NonIterableContainer:
Expand Down
Loading