Skip to content

Commit

Permalink
Refactoring propagation context (#2970)
Browse files Browse the repository at this point in the history
Create a class for the `PropagationContext`. Make the class generate the UUIDs lazily. Fixes #2827
  • Loading branch information
antonpirker committed Apr 25, 2024
1 parent d91a510 commit f5db9ce
Show file tree
Hide file tree
Showing 6 changed files with 231 additions and 91 deletions.
105 changes: 27 additions & 78 deletions sentry_sdk/scope.py
@@ -1,6 +1,5 @@
import os
import sys
import uuid
from copy import copy
from collections import deque
from contextlib import contextmanager
Expand All @@ -15,9 +14,9 @@
from sentry_sdk.session import Session
from sentry_sdk.tracing_utils import (
Baggage,
extract_sentrytrace_data,
has_tracing_enabled,
normalize_incoming_data,
PropagationContext,
)
from sentry_sdk.tracing import (
BAGGAGE_HEADER_NAME,
Expand Down Expand Up @@ -196,7 +195,7 @@ def __init__(self, ty=None, client=None):
self._error_processors = [] # type: List[ErrorProcessor]

self._name = None # type: Optional[str]
self._propagation_context = None # type: Optional[Dict[str, Any]]
self._propagation_context = None # type: Optional[PropagationContext]

self.client = NonRecordingClient() # type: sentry_sdk.client.BaseClient

Expand Down Expand Up @@ -431,77 +430,28 @@ def _load_trace_data_from_env(self):

return incoming_trace_information or None

def _extract_propagation_context(self, data):
# type: (Dict[str, Any]) -> Optional[Dict[str, Any]]
context = {} # type: Dict[str, Any]
normalized_data = normalize_incoming_data(data)

baggage_header = normalized_data.get(BAGGAGE_HEADER_NAME)
if baggage_header:
context["dynamic_sampling_context"] = Baggage.from_incoming_header(
baggage_header
).dynamic_sampling_context()

sentry_trace_header = normalized_data.get(SENTRY_TRACE_HEADER_NAME)
if sentry_trace_header:
sentrytrace_data = extract_sentrytrace_data(sentry_trace_header)
if sentrytrace_data is not None:
context.update(sentrytrace_data)

only_baggage_no_sentry_trace = (
"dynamic_sampling_context" in context and "trace_id" not in context
)
if only_baggage_no_sentry_trace:
context.update(self._create_new_propagation_context())

if context:
if not context.get("span_id"):
context["span_id"] = uuid.uuid4().hex[16:]

return context

return None

def _create_new_propagation_context(self):
# type: () -> Dict[str, Any]
return {
"trace_id": uuid.uuid4().hex,
"span_id": uuid.uuid4().hex[16:],
"parent_span_id": None,
"dynamic_sampling_context": None,
}

def set_new_propagation_context(self):
# type: () -> None
"""
Creates a new propagation context and sets it as `_propagation_context`. Overwriting existing one.
"""
self._propagation_context = self._create_new_propagation_context()
logger.debug(
"[Tracing] Create new propagation context: %s",
self._propagation_context,
)
self._propagation_context = PropagationContext()

def generate_propagation_context(self, incoming_data=None):
# type: (Optional[Dict[str, str]]) -> None
"""
Makes sure the propagation context (`_propagation_context`) is set.
The propagation context only lives on the current scope.
If there is `incoming_data` overwrite existing `_propagation_context`.
if there is no `incoming_data` create new `_propagation_context`, but do NOT overwrite if already existing.
Makes sure the propagation context is set on the scope.
If there is `incoming_data` overwrite existing propagation context.
If there is no `incoming_data` create new propagation context, but do NOT overwrite if already existing.
"""
if incoming_data:
context = self._extract_propagation_context(incoming_data)

if context is not None:
self._propagation_context = context
logger.debug(
"[Tracing] Extracted propagation context from incoming data: %s",
self._propagation_context,
)
propagation_context = PropagationContext.from_incoming_data(incoming_data)
if propagation_context is not None:
self._propagation_context = propagation_context

if self._propagation_context is None and self._type != ScopeType.CURRENT:
self.set_new_propagation_context()
if self._type != ScopeType.CURRENT:
if self._propagation_context is None:
self.set_new_propagation_context()

def get_dynamic_sampling_context(self):
# type: () -> Optional[Dict[str, str]]
Expand All @@ -514,11 +464,11 @@ def get_dynamic_sampling_context(self):

baggage = self.get_baggage()
if baggage is not None:
self._propagation_context["dynamic_sampling_context"] = (
self._propagation_context.dynamic_sampling_context = (
baggage.dynamic_sampling_context()
)

return self._propagation_context["dynamic_sampling_context"]
return self._propagation_context.dynamic_sampling_context

def get_traceparent(self, *args, **kwargs):
# type: (Any, Any) -> Optional[str]
Expand All @@ -535,8 +485,8 @@ def get_traceparent(self, *args, **kwargs):
# If this scope has a propagation context, return traceparent from there
if self._propagation_context is not None:
traceparent = "%s-%s" % (
self._propagation_context["trace_id"],
self._propagation_context["span_id"],
self._propagation_context.trace_id,
self._propagation_context.span_id,
)
return traceparent

Expand All @@ -557,8 +507,8 @@ def get_baggage(self, *args, **kwargs):

# If this scope has a propagation context, return baggage from there
if self._propagation_context is not None:
dynamic_sampling_context = self._propagation_context.get(
"dynamic_sampling_context"
dynamic_sampling_context = (
self._propagation_context.dynamic_sampling_context
)
if dynamic_sampling_context is None:
return Baggage.from_options(self)
Expand All @@ -577,9 +527,9 @@ def get_trace_context(self):
return None

trace_context = {
"trace_id": self._propagation_context["trace_id"],
"span_id": self._propagation_context["span_id"],
"parent_span_id": self._propagation_context["parent_span_id"],
"trace_id": self._propagation_context.trace_id,
"span_id": self._propagation_context.span_id,
"parent_span_id": self._propagation_context.parent_span_id,
"dynamic_sampling_context": self.get_dynamic_sampling_context(),
} # type: Dict[str, Any]

Expand Down Expand Up @@ -667,7 +617,7 @@ def iter_trace_propagation_headers(self, *args, **kwargs):
yield header

def get_active_propagation_context(self):
# type: () -> Dict[str, Any]
# type: () -> Optional[PropagationContext]
if self._propagation_context is not None:
return self._propagation_context

Expand All @@ -679,7 +629,7 @@ def get_active_propagation_context(self):
if isolation_scope._propagation_context is not None:
return isolation_scope._propagation_context

return {}
return None

def clear(self):
# type: () -> None
Expand Down Expand Up @@ -1069,12 +1019,11 @@ def start_span(self, instrumenter=INSTRUMENTER.SENTRY, **kwargs):
span = self.span or Scope.get_isolation_scope().span

if span is None:
# New spans get the `trace_id`` from the scope
# New spans get the `trace_id` from the scope
if "trace_id" not in kwargs:

trace_id = self.get_active_propagation_context().get("trace_id")
if trace_id is not None:
kwargs["trace_id"] = trace_id
propagation_context = self.get_active_propagation_context()
if propagation_context is not None:
kwargs["trace_id"] = propagation_context.trace_id

span = Span(**kwargs)
else:
Expand Down
114 changes: 111 additions & 3 deletions sentry_sdk/tracing_utils.py
Expand Up @@ -7,6 +7,7 @@
from datetime import timedelta
from functools import wraps
from urllib.parse import quote, unquote
import uuid

import sentry_sdk
from sentry_sdk.consts import OP, SPANDATA
Expand Down Expand Up @@ -318,6 +319,109 @@ def _format_sql(cursor, sql):
return real_sql or to_string(sql)


class PropagationContext:
"""
The PropagationContext represents the data of a trace in Sentry.
"""

__slots__ = (
"_trace_id",
"_span_id",
"parent_span_id",
"parent_sampled",
"dynamic_sampling_context",
)

def __init__(
self,
trace_id=None, # type: Optional[str]
span_id=None, # type: Optional[str]
parent_span_id=None, # type: Optional[str]
parent_sampled=None, # type: Optional[bool]
dynamic_sampling_context=None, # type: Optional[Dict[str, str]]
):
# type: (...) -> None
self._trace_id = trace_id
"""The trace id of the Sentry trace."""

self._span_id = span_id
"""The span id of the currently executing span."""

self.parent_span_id = parent_span_id
"""The id of the parent span that started this span.
The parent span could also be a span in an upstream service."""

self.parent_sampled = parent_sampled
"""Boolean indicator if the parent span was sampled.
Important when the parent span originated in an upstream service,
because we watn to sample the whole trace, or nothing from the trace."""

self.dynamic_sampling_context = dynamic_sampling_context
"""Data that is used for dynamic sampling decisions."""

@classmethod
def from_incoming_data(cls, incoming_data):
# type: (Dict[str, Any]) -> Optional[PropagationContext]
propagation_context = None

normalized_data = normalize_incoming_data(incoming_data)
baggage_header = normalized_data.get(BAGGAGE_HEADER_NAME)
if baggage_header:
propagation_context = PropagationContext()
propagation_context.dynamic_sampling_context = Baggage.from_incoming_header(
baggage_header
).dynamic_sampling_context()

sentry_trace_header = normalized_data.get(SENTRY_TRACE_HEADER_NAME)
if sentry_trace_header:
sentrytrace_data = extract_sentrytrace_data(sentry_trace_header)
if sentrytrace_data is not None:
if propagation_context is None:
propagation_context = PropagationContext()
propagation_context.update(sentrytrace_data)

return propagation_context

@property
def trace_id(self):
# type: () -> str
"""The trace id of the Sentry trace."""
if not self._trace_id:
self._trace_id = uuid.uuid4().hex

return self._trace_id

@trace_id.setter
def trace_id(self, value):
# type: (str) -> None
self._trace_id = value

@property
def span_id(self):
# type: () -> str
"""The span id of the currently executed span."""
if not self._span_id:
self._span_id = uuid.uuid4().hex[16:]

return self._span_id

@span_id.setter
def span_id(self, value):
# type: (str) -> None
self._span_id = value

def update(self, other_dict):
# type: (Dict[str, Any]) -> None
"""
Updates the PropagationContext with data from the given dictionary.
"""
for key, value in other_dict.items():
try:
setattr(self, key, value)
except AttributeError:
pass


class Baggage:
"""
The W3C Baggage header information (see https://www.w3.org/TR/baggage/).
Expand Down Expand Up @@ -381,8 +485,8 @@ def from_options(cls, scope):
options = client.options
propagation_context = scope._propagation_context

if propagation_context is not None and "trace_id" in propagation_context:
sentry_items["trace_id"] = propagation_context["trace_id"]
if propagation_context is not None:
sentry_items["trace_id"] = propagation_context.trace_id

if options.get("environment"):
sentry_items["environment"] = options["environment"]
Expand Down Expand Up @@ -568,7 +672,11 @@ def get_current_span(scope=None):


# Circular imports
from sentry_sdk.tracing import LOW_QUALITY_TRANSACTION_SOURCES
from sentry_sdk.tracing import (
BAGGAGE_HEADER_NAME,
LOW_QUALITY_TRANSACTION_SOURCES,
SENTRY_TRACE_HEADER_NAME,
)

if TYPE_CHECKING:
from sentry_sdk.tracing import Span
4 changes: 2 additions & 2 deletions tests/integrations/celery/test_celery.py
Expand Up @@ -154,11 +154,11 @@ def dummy_task(x, y):

assert (
error_event["contexts"]["trace"]["trace_id"]
== scope._propagation_context["trace_id"]
== scope._propagation_context.trace_id
)
assert (
error_event["contexts"]["trace"]["span_id"]
!= scope._propagation_context["span_id"]
!= scope._propagation_context.span_id
)
assert error_event["transaction"] == "dummy_task"
assert "celery_task_id" in error_event["tags"]
Expand Down
2 changes: 1 addition & 1 deletion tests/integrations/rq/test_rq.py
Expand Up @@ -190,7 +190,7 @@ def test_tracing_disabled(
assert error_event["transaction"] == "tests.integrations.rq.test_rq.crashing_job"
assert (
error_event["contexts"]["trace"]["trace_id"]
== scope._propagation_context["trace_id"]
== scope._propagation_context.trace_id
)


Expand Down
14 changes: 7 additions & 7 deletions tests/test_api.py
Expand Up @@ -66,8 +66,8 @@ def test_traceparent_with_tracing_disabled(sentry_init):

propagation_context = Scope.get_isolation_scope()._propagation_context
expected_traceparent = "%s-%s" % (
propagation_context["trace_id"],
propagation_context["span_id"],
propagation_context.trace_id,
propagation_context.span_id,
)
assert get_traceparent() == expected_traceparent

Expand All @@ -78,7 +78,7 @@ def test_baggage_with_tracing_disabled(sentry_init):
propagation_context = Scope.get_isolation_scope()._propagation_context
expected_baggage = (
"sentry-trace_id={},sentry-environment=dev,sentry-release=1.0.0".format(
propagation_context["trace_id"]
propagation_context.trace_id
)
)
assert get_baggage() == expected_baggage
Expand Down Expand Up @@ -112,10 +112,10 @@ def test_continue_trace(sentry_init):
assert transaction.name == "some name"

propagation_context = Scope.get_isolation_scope()._propagation_context
assert propagation_context["trace_id"] == transaction.trace_id == trace_id
assert propagation_context["parent_span_id"] == parent_span_id
assert propagation_context["parent_sampled"] == parent_sampled
assert propagation_context["dynamic_sampling_context"] == {
assert propagation_context.trace_id == transaction.trace_id == trace_id
assert propagation_context.parent_span_id == parent_span_id
assert propagation_context.parent_sampled == parent_sampled
assert propagation_context.dynamic_sampling_context == {
"trace_id": "566e3688a61d4bc888951642d6f14a19"
}

Expand Down

0 comments on commit f5db9ce

Please sign in to comment.