Skip to content

Commit

Permalink
More pythonic propagator + package clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
NathanielRN committed Nov 7, 2020
1 parent 3f680d0 commit 18c529a
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 98 deletions.
4 changes: 2 additions & 2 deletions sdk-extension/opentelemetry-sdk-extension-aws/setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,15 @@ package_dir=
=src
packages=find_namespace:
install_requires =
opentelemetry-api == 0.15.b0
opentelemetry-api == 0.16.dev0

[options.entry_points]
opentelemetry_propagator =
aws_xray = opentelemetry.sdk.extension.aws.trace.propagation.aws_xray_format:AwsXRayFormat

[options.extras_require]
test =
opentelemetry-test == 0.15.b0
opentelemetry-test == 0.16.dev0

[options.packages.find]
where = src
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,6 @@ class AwsXRayFormat(TextMapPropagator):
IS_SAMPLED = "1"
NOT_SAMPLED = "0"

# pylint: disable=too-many-locals
# pylint: disable=too-many-return-statements
# pylint: disable=too-many-branches
# pylint: disable=too-many-statements
def extract(
self,
getter: Getter[TextMapPropagatorT],
Expand All @@ -79,73 +75,78 @@ def extract(
trace.INVALID_SPAN, context=context
)

trace_id, span_id, sampled, err = self.extract_span_properties(
trace_header
)

if err is not None:
return trace.set_span_in_context(
trace.INVALID_SPAN, context=context
)

options = 0
if sampled:
options |= trace.TraceFlags.SAMPLED

span_context = trace.SpanContext(
trace_id=trace_id,
span_id=span_id,
is_remote=True,
trace_flags=trace.TraceFlags(options),
trace_state=trace.TraceState(),
)

if not span_context.is_valid:
_logger.error(
"Invalid Span Extracted. Insertting INVALID span into provided context."
)
return trace.set_span_in_context(
trace.INVALID_SPAN, context=context
)

return trace.set_span_in_context(
trace.DefaultSpan(span_context), context=context
)

def extract_span_properties(self, trace_header):
trace_id = trace.INVALID_TRACE_ID
span_id = trace.INVALID_SPAN_ID
sampled = False

next_kv_pair_start = 0
extract_err = None

while next_kv_pair_start < len(trace_header):
try:
kv_pair_delimiter_index = trace_header.index(
self.KV_PAIR_DELIMITER, next_kv_pair_start
)
kv_pair_subset = trace_header[
next_kv_pair_start:kv_pair_delimiter_index
]
next_kv_pair_start = kv_pair_delimiter_index + 1
except ValueError:
kv_pair_subset = trace_header[next_kv_pair_start:]
next_kv_pair_start = len(trace_header)

stripped_kv_pair = kv_pair_subset.strip()
for kv_pair_str in trace_header.split(self.KV_PAIR_DELIMITER):
if extract_err:
break

try:
key_and_value_delimiter_index = stripped_kv_pair.index(
key_str, value_str = kv_pair_str.split(
self.KEY_AND_VALUE_DELIMITER
)
key, value = key_str.strip(), value_str.strip()
except ValueError:
_logger.error(
(
"Error parsing X-Ray trace header. Invalid key value pair: %s. Returning INVALID span context.",
kv_pair_subset,
kv_pair_str,
)
)
return trace.set_span_in_context(
trace.INVALID_SPAN, context=context
)
return trace_id, span_id, sampled, extract_err

value = stripped_kv_pair[key_and_value_delimiter_index + 1 :]

if stripped_kv_pair.startswith(self.TRACE_ID_KEY):
if (
len(value) != self.TRACE_ID_LENGTH
or not value.startswith(self.TRACE_ID_VERSION)
or value[self.TRACE_ID_DELIMITER_INDEX_1]
!= self.TRACE_ID_DELIMITER
or value[self.TRACE_ID_DELIMITER_INDEX_2]
!= self.TRACE_ID_DELIMITER
):
if key == self.TRACE_ID_KEY:
if not self.validate_trace_id(value):
_logger.error(
(
"Invalid TraceId in X-Ray trace header: '%s' with value '%s'. Returning INVALID span context.",
self.TRACE_HEADER_KEY,
trace_header,
)
)
return trace.set_span_in_context(
trace.INVALID_SPAN, context=context
)
extract_err = True
break

timestamp_subset = value[
self.TRACE_ID_DELIMITER_INDEX_1
+ 1 : self.TRACE_ID_DELIMITER_INDEX_2
]
unique_id_subset = value[
self.TRACE_ID_DELIMITER_INDEX_2 + 1 : self.TRACE_ID_LENGTH
]
try:
trace_id = int(timestamp_subset + unique_id_subset, 16)
trace_id = self.parse_trace_id(value)
except ValueError:
_logger.error(
(
Expand All @@ -154,24 +155,21 @@ def extract(
trace_header,
)
)
return trace.set_span_in_context(
trace.INVALID_SPAN, context=context
)
elif stripped_kv_pair.startswith(self.PARENT_ID_KEY):
if len(value) != self.PARENT_ID_LENGTH:
extract_err = True
elif key == self.PARENT_ID_KEY:
if not self.validate_span_id(value):
_logger.error(
(
"Invalid ParentId in X-Ray trace header: '%s' with value '%s'. Returning INVALID span context.",
self.TRACE_HEADER_KEY,
trace_header,
)
)
return trace.set_span_in_context(
trace.INVALID_SPAN, context=context
)
extract_err = True
break

try:
span_id = int(value, 16)
span_id = AwsXRayFormat.parse_span_id(value)
except ValueError:
_logger.error(
(
Expand All @@ -180,60 +178,61 @@ def extract(
trace_header,
)
)
return trace.set_span_in_context(
trace.INVALID_SPAN, context=context
)
elif stripped_kv_pair.startswith(self.SAMPLED_FLAG_KEY):
is_sampled_flag_valid = True

if len(value) != self.SAMPLED_FLAG_LENGTH:
is_sampled_flag_valid = False

if is_sampled_flag_valid:
sampled_flag = value[0]
if sampled_flag == self.IS_SAMPLED:
sampled = True
elif sampled_flag == self.NOT_SAMPLED:
sampled = False
else:
is_sampled_flag_valid = False

if not is_sampled_flag_valid:
extract_err = True
elif key == self.SAMPLED_FLAG_KEY:
if not self.validate_sampled_flag(value):
_logger.error(
(
"Invalid Sampling flag in X-Ray trace header: '%s' with value '%s'. Returning INVALID span context.",
self.TRACE_HEADER_KEY,
trace_header,
)
)
return trace.set_span_in_context(
trace.INVALID_SPAN, context=context
)
extract_err = True
break

options = 0
if sampled:
options |= trace.TraceFlags.SAMPLED
sampled = self.parse_sampled_flag(value)

span_context = trace.SpanContext(
trace_id=trace_id,
span_id=span_id,
is_remote=True,
trace_flags=trace.TraceFlags(options),
trace_state=trace.TraceState(),
)
return trace_id, span_id, sampled, extract_err

if not span_context.is_valid:
_logger.error(
"Invalid Span Extracted. Insertting INVALID span into provided context."
)
return trace.set_span_in_context(
trace.INVALID_SPAN, context=context
)
def validate_trace_id(self, trace_id_str):
return (
len(trace_id_str) == self.TRACE_ID_LENGTH
and trace_id_str.startswith(self.TRACE_ID_VERSION)
and trace_id_str[self.TRACE_ID_DELIMITER_INDEX_1]
== self.TRACE_ID_DELIMITER
and trace_id_str[self.TRACE_ID_DELIMITER_INDEX_2]
== self.TRACE_ID_DELIMITER
)

return trace.set_span_in_context(
trace.DefaultSpan(span_context), context=context
def parse_trace_id(self, trace_id_str):
timestamp_subset = trace_id_str[
self.TRACE_ID_DELIMITER_INDEX_1
+ 1 : self.TRACE_ID_DELIMITER_INDEX_2
]
unique_id_subset = trace_id_str[
self.TRACE_ID_DELIMITER_INDEX_2 + 1 : self.TRACE_ID_LENGTH
]
return int(timestamp_subset + unique_id_subset, 16)

def validate_span_id(self, span_id_str):
return len(span_id_str) == self.PARENT_ID_LENGTH

@staticmethod
def parse_span_id(span_id_str):
return int(span_id_str, 16)

def validate_sampled_flag(self, sampled_flag_str):
return len(
sampled_flag_str
) == self.SAMPLED_FLAG_LENGTH and sampled_flag_str in (
self.IS_SAMPLED,
self.NOT_SAMPLED,
)

def parse_sampled_flag(self, sampled_flag_str):
return sampled_flag_str[0] == self.IS_SAMPLED

def inject(
self,
set_in_carrier: Setter[TextMapPropagatorT],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,41 @@ def test_extract_with_additional_fields(self):
get_extracted_span_context(build_test_context()),
)

def test_extract_with_extra_whitespace(self):
default_xray_trace_header_dict = build_dict_with_xray_trace_header()
trace_header_components = default_xray_trace_header_dict[
AwsXRayFormat.TRACE_HEADER_KEY
].split(AwsXRayFormat.KV_PAIR_DELIMITER)
xray_trace_header_dict_with_extra_whitespace = CaseInsensitiveDict(
{
AwsXRayFormat.TRACE_HEADER_KEY: AwsXRayFormat.KV_PAIR_DELIMITER.join(
[
AwsXRayFormat.KEY_AND_VALUE_DELIMITER.join(
[
" " + key + " ",
" " + value + " ",
]
)
for kv_pair_str in trace_header_components
for key, value in [
kv_pair_str.split(
AwsXRayFormat.KEY_AND_VALUE_DELIMITER
)
]
]
)
}
)
actual_context_encompassing_extracted = AwsXRayPropagatorTest.XRAY_PROPAGATOR.extract(
AwsXRayPropagatorTest.carrier_getter,
xray_trace_header_dict_with_extra_whitespace,
)

self.assertEqual(
get_extracted_span_context(actual_context_encompassing_extracted),
get_extracted_span_context(build_test_context()),
)

def test_extract_invalid_xray_trace_header(self):
actual_context_encompassing_extracted = AwsXRayPropagatorTest.XRAY_PROPAGATOR.extract(
AwsXRayPropagatorTest.carrier_getter,
Expand Down

0 comments on commit 18c529a

Please sign in to comment.