Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bugfix: RandomIdGenerator can generate invalid Span/Trace Ids #3949

Merged
merged 9 commits into from
Jun 13, 2024
Merged
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

- Fix RandomIdGenerator can generate invalid Span/Trace Ids
([#3949](https://github.com/open-telemetry/opentelemetry-python/pull/3949))
- Add Python 3.12 to tox
([#3616](https://github.com/open-telemetry/opentelemetry-python/pull/3616))

Expand Down
12 changes: 10 additions & 2 deletions opentelemetry-sdk/src/opentelemetry/sdk/trace/id_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import abc
import random

from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID

zhihali marked this conversation as resolved.
Show resolved Hide resolved

class IdGenerator(abc.ABC):
@abc.abstractmethod
Expand Down Expand Up @@ -46,7 +48,13 @@ class RandomIdGenerator(IdGenerator):
"""

def generate_span_id(self) -> int:
return random.getrandbits(64)
span_id = random.getrandbits(64)
while span_id == INVALID_SPAN_ID:
span_id = random.getrandbits(64)
return span_id

def generate_trace_id(self) -> int:
return random.getrandbits(128)
trace_id = random.getrandbits(128)
while trace_id == INVALID_TRACE_ID:
trace_id = random.getrandbits(128)
return trace_id
36 changes: 36 additions & 0 deletions opentelemetry-sdk/tests/trace/test_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -2061,3 +2061,39 @@ def test_tracer_provider_init_default(self, resource_patch, sample_patch):
sample_patch.assert_called_once()
self.assertIsNotNone(tracer_provider._span_limits)
self.assertIsNotNone(tracer_provider._atexit_handler)


class TestRandomIdGenerator(unittest.TestCase):
_TRACE_ID_MAX_VALUE = 2**128 - 1
_SPAN_ID_MAX_VALUE = 2**64 - 1

@patch(
"random.getrandbits",
side_effect=[trace_api.INVALID_SPAN_ID, 0x00000000DEADBEF0],
)
def test_generate_span_id_avoids_invalid(self, mock_getrandbits):
generator = RandomIdGenerator()
span_id = generator.generate_span_id()

self.assertGreater(span_id, trace_api.INVALID_SPAN_ID)
self.assertLessEqual(span_id, self._SPAN_ID_MAX_VALUE)
self.assertEqual(
mock_getrandbits.call_count, 2
) # Ensure exactly two calls

@patch(
"random.getrandbits",
side_effect=[
trace_api.INVALID_TRACE_ID,
0x000000000000000000000000DEADBEEF,
],
)
def test_generate_trace_id_avoids_invalid(self, mock_getrandbits):
zhihali marked this conversation as resolved.
Show resolved Hide resolved
generator = RandomIdGenerator()
trace_id = generator.generate_trace_id()

self.assertGreater(trace_id, trace_api.INVALID_TRACE_ID)
self.assertLessEqual(trace_id, self._TRACE_ID_MAX_VALUE)
self.assertEqual(
mock_getrandbits.call_count, 2
) # Ensure exactly two calls