Skip to content

Commit

Permalink
Expose AWS Progagator variables and update readme
Browse files Browse the repository at this point in the history
  • Loading branch information
NathanielRN committed Nov 9, 2020
1 parent ea0988a commit dcc8f04
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 97 deletions.
2 changes: 1 addition & 1 deletion sdk-extension/opentelemetry-sdk-extension-aws/README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ Propagator:

::

export OTEL_PYTHON_PROPAGATORS = aws_xray
export OTEL_PROPAGATORS = aws_xray


References
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,27 @@
TextMapPropagatorT,
)

TRACE_HEADER_KEY = "X-Amzn-Trace-Id"
KV_PAIR_DELIMITER = ";"
KEY_AND_VALUE_DELIMITER = "="

TRACE_ID_KEY = "Root"
TRACE_ID_LENGTH = 35
TRACE_ID_VERSION = "1"
TRACE_ID_DELIMITER = "-"
TRACE_ID_DELIMITER_INDEX_1 = 1
TRACE_ID_DELIMITER_INDEX_2 = 10
TRACE_ID_FIRST_PART_LENGTH = 8

PARENT_ID_KEY = "Parent"
PARENT_ID_LENGTH = 16

SAMPLED_FLAG_KEY = "Sampled"
SAMPLED_FLAG_LENGTH = 1
IS_SAMPLED = "1"
NOT_SAMPLED = "0"


_logger = logging.getLogger(__name__)


Expand All @@ -40,35 +61,14 @@ class AwsXRayFormat(TextMapPropagator):
"""

# AWS
TRACE_HEADER_KEY = "X-Amzn-Trace-Id"

KV_PAIR_DELIMITER = ";"
KEY_AND_VALUE_DELIMITER = "="

TRACE_ID_KEY = "Root"
TRACE_ID_LENGTH = 35
TRACE_ID_VERSION = "1"
TRACE_ID_DELIMITER = "-"
TRACE_ID_DELIMITER_INDEX_1 = 1
TRACE_ID_DELIMITER_INDEX_2 = 10
TRACE_ID_FIRST_PART_LENGTH = 8

PARENT_ID_KEY = "Parent"
PARENT_ID_LENGTH = 16

SAMPLED_FLAG_KEY = "Sampled"
SAMPLED_FLAG_LENGTH = 1
IS_SAMPLED = "1"
NOT_SAMPLED = "0"

def extract(
self,
getter: Getter[TextMapPropagatorT],
carrier: TextMapPropagatorT,
context: typing.Optional[Context] = None,
) -> Context:
trace_header_list = getter(carrier, self.TRACE_HEADER_KEY)
trace_header_list = getter.get(carrier, self.TRACE_HEADER_KEY)
trace_header_list = getter.get(carrier, TRACE_HEADER_KEY)

if not trace_header_list or len(trace_header_list) != 1:
return trace.set_span_in_context(
Expand All @@ -83,9 +83,11 @@ def extract(
)

try:
trace_id, span_id, sampled = self._extract_span_properties(
trace_header
)
(
trace_id,
span_id,
sampled,
) = AwsXRayFormat._extract_span_properties(trace_header)
except AwsParseTraceHeaderError as err:
_logger.debug(err.message)
return trace.set_span_in_context(
Expand Down Expand Up @@ -116,16 +118,15 @@ def extract(
trace.DefaultSpan(span_context), context=context
)

def _extract_span_properties(self, trace_header):
@staticmethod
def _extract_span_properties(trace_header):
trace_id = trace.INVALID_TRACE_ID
span_id = trace.INVALID_SPAN_ID
sampled = False

for kv_pair_str in trace_header.split(self.KV_PAIR_DELIMITER):
for kv_pair_str in trace_header.split(KV_PAIR_DELIMITER):
try:
key_str, value_str = kv_pair_str.split(
self.KEY_AND_VALUE_DELIMITER
)
key_str, value_str = kv_pair_str.split(KEY_AND_VALUE_DELIMITER)
key, value = key_str.strip(), value_str.strip()
except ValueError as ex:
raise AwsParseTraceHeaderError(
Expand All @@ -134,32 +135,32 @@ def _extract_span_properties(self, trace_header):
kv_pair_str,
)
) from ex
if key == self.TRACE_ID_KEY:
if not self._validate_trace_id(value):
if key == TRACE_ID_KEY:
if not AwsXRayFormat._validate_trace_id(value):
raise AwsParseTraceHeaderError(
(
"Invalid TraceId in X-Ray trace header: '%s' with value '%s'. Returning INVALID span context.",
self.TRACE_HEADER_KEY,
TRACE_HEADER_KEY,
trace_header,
)
)

try:
trace_id = self._parse_trace_id(value)
trace_id = AwsXRayFormat._parse_trace_id(value)
except ValueError as ex:
raise AwsParseTraceHeaderError(
(
"Invalid TraceId in X-Ray trace header: '%s' with value '%s'. Returning INVALID span context.",
self.TRACE_HEADER_KEY,
TRACE_HEADER_KEY,
trace_header,
)
) from ex
elif key == self.PARENT_ID_KEY:
if not self._validate_span_id(value):
elif key == PARENT_ID_KEY:
if not AwsXRayFormat._validate_span_id(value):
raise AwsParseTraceHeaderError(
(
"Invalid ParentId in X-Ray trace header: '%s' with value '%s'. Returning INVALID span context.",
self.TRACE_HEADER_KEY,
TRACE_HEADER_KEY,
trace_header,
)
)
Expand All @@ -170,61 +171,63 @@ def _extract_span_properties(self, trace_header):
raise AwsParseTraceHeaderError(
(
"Invalid TraceId in X-Ray trace header: '%s' with value '%s'. Returning INVALID span context.",
self.TRACE_HEADER_KEY,
TRACE_HEADER_KEY,
trace_header,
)
) from ex
elif key == self.SAMPLED_FLAG_KEY:
if not self._validate_sampled_flag(value):
elif key == SAMPLED_FLAG_KEY:
if not AwsXRayFormat._validate_sampled_flag(value):
raise AwsParseTraceHeaderError(
(
"Invalid Sampling flag in X-Ray trace header: '%s' with value '%s'. Returning INVALID span context.",
self.TRACE_HEADER_KEY,
TRACE_HEADER_KEY,
trace_header,
)
)

sampled = self._parse_sampled_flag(value)
sampled = AwsXRayFormat._parse_sampled_flag(value)

return trace_id, span_id, sampled

def _validate_trace_id(self, trace_id_str):
@staticmethod
def _validate_trace_id(trace_id_str):
return (
len(trace_id_str) == self.TRACE_ID_LENGTH
and trace_id_str.startswith(self.TRACE_ID_VERSION)
and trace_id_str[self.TRACE_ID_DELIMITER_INDEX_1]
== self.TRACE_ID_DELIMITER
and trace_id_str[self.TRACE_ID_DELIMITER_INDEX_2]
== self.TRACE_ID_DELIMITER
len(trace_id_str) == TRACE_ID_LENGTH
and trace_id_str.startswith(TRACE_ID_VERSION)
and trace_id_str[TRACE_ID_DELIMITER_INDEX_1] == TRACE_ID_DELIMITER
and trace_id_str[TRACE_ID_DELIMITER_INDEX_2] == TRACE_ID_DELIMITER
)

def _parse_trace_id(self, trace_id_str):
@staticmethod
def _parse_trace_id(trace_id_str):
timestamp_subset = trace_id_str[
self.TRACE_ID_DELIMITER_INDEX_1
+ 1 : self.TRACE_ID_DELIMITER_INDEX_2
TRACE_ID_DELIMITER_INDEX_1 + 1 : TRACE_ID_DELIMITER_INDEX_2
]
unique_id_subset = trace_id_str[
self.TRACE_ID_DELIMITER_INDEX_2 + 1 : self.TRACE_ID_LENGTH
TRACE_ID_DELIMITER_INDEX_2 + 1 : TRACE_ID_LENGTH
]
return int(timestamp_subset + unique_id_subset, 16)

def _validate_span_id(self, span_id_str):
return len(span_id_str) == self.PARENT_ID_LENGTH
@staticmethod
def _validate_span_id(span_id_str):
return len(span_id_str) == PARENT_ID_LENGTH

@staticmethod
def _parse_span_id(span_id_str):
return int(span_id_str, 16)

def _validate_sampled_flag(self, sampled_flag_str):
@staticmethod
def _validate_sampled_flag(sampled_flag_str):
return len(
sampled_flag_str
) == self.SAMPLED_FLAG_LENGTH and sampled_flag_str in (
self.IS_SAMPLED,
self.NOT_SAMPLED,
) == SAMPLED_FLAG_LENGTH and sampled_flag_str in (
IS_SAMPLED,
NOT_SAMPLED,
)

def _parse_sampled_flag(self, sampled_flag_str):
return sampled_flag_str[0] == self.IS_SAMPLED
@staticmethod
def _parse_sampled_flag(sampled_flag_str):
return sampled_flag_str[0] == IS_SAMPLED

def inject(
self,
Expand All @@ -240,37 +243,37 @@ def inject(

otel_trace_id = "{:032x}".format(span_context.trace_id)
xray_trace_id = (
self.TRACE_ID_VERSION
+ self.TRACE_ID_DELIMITER
+ otel_trace_id[: self.TRACE_ID_FIRST_PART_LENGTH]
+ self.TRACE_ID_DELIMITER
+ otel_trace_id[self.TRACE_ID_FIRST_PART_LENGTH :]
TRACE_ID_VERSION
+ TRACE_ID_DELIMITER
+ otel_trace_id[:TRACE_ID_FIRST_PART_LENGTH]
+ TRACE_ID_DELIMITER
+ otel_trace_id[TRACE_ID_FIRST_PART_LENGTH:]
)

parent_id = "{:016x}".format(span_context.span_id)

sampling_flag = (
self.IS_SAMPLED
IS_SAMPLED
if span_context.trace_flags & trace.TraceFlags.SAMPLED
else self.NOT_SAMPLED
else NOT_SAMPLED
)

# TODO: Add OT trace state to the X-Ray trace header

trace_header = (
self.TRACE_ID_KEY
+ self.KEY_AND_VALUE_DELIMITER
TRACE_ID_KEY
+ KEY_AND_VALUE_DELIMITER
+ xray_trace_id
+ self.KV_PAIR_DELIMITER
+ self.PARENT_ID_KEY
+ self.KEY_AND_VALUE_DELIMITER
+ KV_PAIR_DELIMITER
+ PARENT_ID_KEY
+ KEY_AND_VALUE_DELIMITER
+ parent_id
+ self.KV_PAIR_DELIMITER
+ self.SAMPLED_FLAG_KEY
+ self.KEY_AND_VALUE_DELIMITER
+ KV_PAIR_DELIMITER
+ SAMPLED_FLAG_KEY
+ KEY_AND_VALUE_DELIMITER
+ sampling_flag
)

set_in_carrier(
carrier, self.TRACE_HEADER_KEY, trace_header,
carrier, TRACE_HEADER_KEY, trace_header,
)
Loading

0 comments on commit dcc8f04

Please sign in to comment.