Skip to content

Commit

Permalink
Use an exception to catch inability to parse
Browse files Browse the repository at this point in the history
  • Loading branch information
NathanielRN committed Nov 8, 2020
1 parent d9a2c8c commit 7dee36c
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 81 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import random
import time

from opentelemetry import trace

Expand All @@ -28,10 +28,13 @@ class AwsXRayIdsGenerator(trace.IdsGenerator):
See: https://docs.aws.amazon.com/xray/latest/devguide/xray-api-sendingdata.html#xray-api-traceids
"""

random_ids_generator = trace.RandomIdsGenerator()

def generate_span_id(self) -> int:
return trace.RandomIdsGenerator().generate_span_id()
return self.random_ids_generator.generate_span_id()

def generate_trace_id(self) -> int:
trace_time = int(datetime.datetime.utcnow().timestamp())
@staticmethod
def generate_trace_id() -> int:
trace_time = int(time.time())
trace_identifier = random.getrandbits(96)
return (trace_time << 96) + trace_identifier
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@
_logger = logging.getLogger(__name__)


class AwsParseTraceHeaderError(Exception):
def __init__(self, message):
super().__init__()
self.message = message


class AwsXRayFormat(TextMapPropagator):
"""Propagator for the AWS X-Ray Trace Header propagation protocol.
Expand Down Expand Up @@ -75,11 +81,12 @@ def extract(
trace.INVALID_SPAN, context=context
)

trace_id, span_id, sampled, err = self.extract_span_properties(
trace_header
)

if err is not None:
try:
trace_id, span_id, sampled = self._extract_span_properties(
trace_header
)
except AwsParseTraceHeaderError as err:
_logger.debug(err.message)
return trace.set_span_in_context(
trace.INVALID_SPAN, context=context
)
Expand All @@ -97,7 +104,7 @@ def extract(
)

if not span_context.is_valid:
_logger.error(
_logger.debug(
"Invalid Span Extracted. Insertting INVALID span into provided context."
)
return trace.set_span_in_context(
Expand All @@ -108,94 +115,79 @@ def extract(
trace.DefaultSpan(span_context), context=context
)

def extract_span_properties(self, trace_header):
def _extract_span_properties(self, trace_header):
trace_id = trace.INVALID_TRACE_ID
span_id = trace.INVALID_SPAN_ID
sampled = False

extract_err = None

for kv_pair_str in trace_header.split(self.KV_PAIR_DELIMITER):
if extract_err:
break

try:
key_str, value_str = kv_pair_str.split(
self.KEY_AND_VALUE_DELIMITER
)
key, value = key_str.strip(), value_str.strip()
except ValueError:
_logger.error(
except ValueError as ex:
raise AwsParseTraceHeaderError(
(
"Error parsing X-Ray trace header. Invalid key value pair: %s. Returning INVALID span context.",
kv_pair_str,
)
)
return trace_id, span_id, sampled, extract_err

) from ex
if key == self.TRACE_ID_KEY:
if not self.validate_trace_id(value):
_logger.error(
if not self._validate_trace_id(value):
raise AwsParseTraceHeaderError(
(
"Invalid TraceId in X-Ray trace header: '%s' with value '%s'. Returning INVALID span context.",
self.TRACE_HEADER_KEY,
trace_header,
)
)
extract_err = True
break

try:
trace_id = self.parse_trace_id(value)
except ValueError:
_logger.error(
trace_id = self._parse_trace_id(value)
except ValueError as ex:
raise AwsParseTraceHeaderError(
(
"Invalid TraceId in X-Ray trace header: '%s' with value '%s'. Returning INVALID span context.",
self.TRACE_HEADER_KEY,
trace_header,
)
)
extract_err = True
) from ex
elif key == self.PARENT_ID_KEY:
if not self.validate_span_id(value):
_logger.error(
if not self._validate_span_id(value):
raise AwsParseTraceHeaderError(
(
"Invalid ParentId in X-Ray trace header: '%s' with value '%s'. Returning INVALID span context.",
self.TRACE_HEADER_KEY,
trace_header,
)
)
extract_err = True
break

try:
span_id = AwsXRayFormat.parse_span_id(value)
except ValueError:
_logger.error(
span_id = AwsXRayFormat._parse_span_id(value)
except ValueError as ex:
raise AwsParseTraceHeaderError(
(
"Invalid TraceId in X-Ray trace header: '%s' with value '%s'. Returning INVALID span context.",
self.TRACE_HEADER_KEY,
trace_header,
)
)
extract_err = True
) from ex
elif key == self.SAMPLED_FLAG_KEY:
if not self.validate_sampled_flag(value):
_logger.error(
if not self._validate_sampled_flag(value):
raise AwsParseTraceHeaderError(
(
"Invalid Sampling flag in X-Ray trace header: '%s' with value '%s'. Returning INVALID span context.",
self.TRACE_HEADER_KEY,
trace_header,
)
)
extract_err = True
break

sampled = self.parse_sampled_flag(value)
sampled = self._parse_sampled_flag(value)

return trace_id, span_id, sampled, extract_err
return trace_id, span_id, sampled

def validate_trace_id(self, trace_id_str):
def _validate_trace_id(self, trace_id_str):
return (
len(trace_id_str) == self.TRACE_ID_LENGTH
and trace_id_str.startswith(self.TRACE_ID_VERSION)
Expand All @@ -205,7 +197,7 @@ def validate_trace_id(self, trace_id_str):
== self.TRACE_ID_DELIMITER
)

def parse_trace_id(self, trace_id_str):
def _parse_trace_id(self, trace_id_str):
timestamp_subset = trace_id_str[
self.TRACE_ID_DELIMITER_INDEX_1
+ 1 : self.TRACE_ID_DELIMITER_INDEX_2
Expand All @@ -215,22 +207,22 @@ def parse_trace_id(self, trace_id_str):
]
return int(timestamp_subset + unique_id_subset, 16)

def validate_span_id(self, span_id_str):
def _validate_span_id(self, span_id_str):
return len(span_id_str) == self.PARENT_ID_LENGTH

@staticmethod
def parse_span_id(span_id_str):
def _parse_span_id(span_id_str):
return int(span_id_str, 16)

def validate_sampled_flag(self, sampled_flag_str):
def _validate_sampled_flag(self, sampled_flag_str):
return len(
sampled_flag_str
) == self.SAMPLED_FLAG_LENGTH and sampled_flag_str in (
self.IS_SAMPLED,
self.NOT_SAMPLED,
)

def parse_sampled_flag(self, sampled_flag_str):
def _parse_sampled_flag(self, sampled_flag_str):
return sampled_flag_str[0] == self.IS_SAMPLED

def inject(
Expand Down
13 changes: 0 additions & 13 deletions sdk-extension/opentelemetry-sdk-extension-aws/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +0,0 @@
# Copyright The OpenTelemetry Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Original file line number Diff line number Diff line change
@@ -1,13 +0,0 @@
# Copyright The OpenTelemetry Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import datetime
import time
import unittest

from opentelemetry.sdk.extension.aws.trace import AwsXRayIdsGenerator
Expand All @@ -33,11 +34,9 @@ def test_id_timestamps_are_acceptable_for_xray(self):
for _ in range(1000):
trace_id = ids_generator.generate_trace_id()
trace_id_time = trace_id >> 96
current_time = int(datetime.datetime.utcnow().timestamp())
current_time = int(time.time())
self.assertLessEqual(trace_id_time, current_time)
one_month_ago_time = int(
(
datetime.datetime.utcnow() - datetime.timedelta(30)
).timestamp()
(datetime.datetime.now() - datetime.timedelta(30)).timestamp()
)
self.assertGreater(trace_id_time, one_month_ago_time)

0 comments on commit 7dee36c

Please sign in to comment.