Skip to content

Commit

Permalink
187 B3 Headers Support
Browse files Browse the repository at this point in the history
Add support for B3 trace headers whil continuing to support the old
trace headers. This change aims to allow both both B3 and the old
headers to be used as long as they contain identical values. For
example:

```
X-Trace: 1234567890
X-B3-TraceId: 1234567890
```

Fixes reddit#187
  • Loading branch information
cshoe committed Aug 7, 2018
1 parent d0593ed commit eb71fff
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 29 deletions.
57 changes: 57 additions & 0 deletions baseplate/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ class TraceInfo(_TraceInfo):
collecting the trace context and passing it along to the server span.
"""

@classmethod
def new(cls):
"""Generate IDs for a new initial server span.
Expand Down Expand Up @@ -135,6 +136,62 @@ def from_upstream(cls, trace_id, parent_id, span_id, sampled, flags):

return cls(trace_id, parent_id, span_id, sampled, flags)

@classmethod
def extract_upstream_header_values(cls, upstream_header_names, headers):
"""Extract values from upstream headers.
This method thinks about upstream headers by a general name as oppposed to the header
name, i.e. "trace_id" instead of "X-Trace". These general names are "trace_id",
"span_id", "parent_span_id", "sampled" and "flags".
A dict mapping these general names to corresponding header names is expected.
For example:
{
"trace_id": ("X-Trace", "X-B3-TraceId"),
"span_id": ("X-Span", "X-B3-SpanId"),
"parent_span_id": ("X-Parent", "X-B3-ParentSpanId"),
"sampled": ("X-Sampled", "X-B3-Sampled"),
"flags": ("X-Flags", "X-B3-Flags"),
}
This structure is used to extract relevant values from the request headers resulting
in a dict mapping general names to values.
For example:
{
"trace_id": "2391921232992245445",
"span_id": "7638783876913511395",
"parent_span_id": "3383915029748331832",
"sampled": "1",
}
:param dict upstream_headers_name: Map of general upstream value labels to header names
:param dict headers: Headers sent with a request
:return: Values found in upstream trace headers
:rtype: dict
:raises: :py:exc:`ValueError` if conflicting values are found for the same header category
"""
extracted_values = {}
for name, header_names in upstream_header_names.items():
values = []
for header_name in header_names:
if header_name in headers:
values.append(headers[header_name])

if not values:
continue
elif not all(value == values[0] for value in values):
raise ValueError("Conflicting values found for %s header(s)".format(header_names))
else:
# All the values are the same
extracted_values[name] = values[0]
return extracted_values


class AuthenticationTokenValidator(object):
"""Factory that knows how to validate raw authentication tokens."""
Expand Down
35 changes: 21 additions & 14 deletions baseplate/integration/pyramid.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,15 @@ def make_app(app_config):
from ..server import make_app


TRACE_HEADER_NAMES = {
"trace_id": ("X-Trace", "X-B3-TraceId"),
"span_id": ("X-Span", "X-B3-SpanId"),
"parent_span_id": ("X-Parent", "X-B3-ParentSpanId"),
"sampled": ("X-Sampled", "X-B3-Sampled"),
"flags": ("X-Flags", "X-B3-Flags"),
}


def _make_baseplate_tween(handler, registry):
def baseplate_tween(request):
try:
Expand Down Expand Up @@ -116,20 +125,7 @@ def _on_new_request(self, event):
trace_info = None
if self.trust_trace_headers:
try:
sampled = request.headers.get("X-Sampled", None)
if sampled is not None:
sampled = True if sampled == "1" else False
flags = request.headers.get("X-Flags", None)
if flags is not None:
flags = int(flags)
trace_info = TraceInfo.from_upstream(
trace_id=int(request.headers["X-Trace"]),
parent_id=int(request.headers["X-Parent"]),
span_id=int(request.headers["X-Span"]),
sampled=sampled,
flags=flags,
)

trace_info = self._get_trace_info(request.headers)
edge_payload = request.headers.get("X-Edge-Request", None)
if self.edge_context_factory:
edge_context = self.edge_context_factory.from_upstream(
Expand All @@ -156,6 +152,17 @@ def _start_server_span(self, request, name, trace_info=None):
request.trace.start()
request.registry.notify(ServerSpanInitialized(request))

def _get_trace_info(self, headers):
extracted_values = TraceInfo.extract_upstream_header_values(TRACE_HEADER_NAMES, headers)
flags = extracted_values.get("flags", None)
return TraceInfo.from_upstream(
int(extracted_values["trace_id"]),
int(extracted_values["parent_span_id"]),
int(extracted_values["span_id"]),
True if extracted_values["sampled"] == "1" else False,
int(flags) if flags is not None else None,
)

def includeme(self, config):
config.add_subscriber(self._on_new_request, pyramid.events.ContextFound)

Expand Down
36 changes: 21 additions & 15 deletions baseplate/integration/thrift/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,15 @@ def make_processor(app_config):
from ...core import TraceInfo


TRACE_HEADER_NAMES = {
"trace_id": (b"Trace", b"B3-TraceId"),
"span_id": (b"Span", b"B3-SpanId"),
"parent_span_id": (b"Parent", b"B3-ParentSpanId"),
"sampled": (b"Sampled", b"B3-Sampled"),
"flags": (b"Flags", b"B3-Flags"),
}


class RequestContext(object):
pass

Expand Down Expand Up @@ -56,21 +65,7 @@ def getHandlerContext(self, fn_name, server_context):
trace_info = None
headers = server_context.iprot.trans.get_headers()
try:
sampled = headers.get(b"Sampled", None)
if sampled is not None:
sampled = True if sampled.decode('utf-8') == "1" else False
flags = headers.get(b"Flags", None)
if flags is not None:
flags = int(flags)

trace_info = TraceInfo.from_upstream(
trace_id=int(headers[b"Trace"]),
parent_id=int(headers[b"Parent"]),
span_id=int(headers[b"Span"]),
sampled=sampled,
flags=flags,
)

trace_info = self._get_trace_info(headers)
edge_payload = headers.get(b"Edge-Request", None)
if self.edge_context_factory:
edge_context = self.edge_context_factory.from_upstream(
Expand Down Expand Up @@ -115,3 +110,14 @@ def handlerError(self, handler_context, fn_name, exception):
handler_context.trace.finish(exc_info=sys.exc_info())
handler_context.trace.is_finished = True
self.logger.exception("Unexpected exception in %r.", fn_name)

def _get_trace_info(self, headers):
extracted_values = TraceInfo.extract_upstream_header_values(TRACE_HEADER_NAMES, headers)
flags = extracted_values.get("flags", None)
return TraceInfo.from_upstream(
int(extracted_values["trace_id"]),
int(extracted_values["parent_span_id"]),
int(extracted_values["span_id"]),
True if extracted_values["sampled"].decode("utf-8") == "1" else False,
int(flags) if flags is not None else None,
)
25 changes: 25 additions & 0 deletions tests/integration/pyramid_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,31 @@ def test_trace_headers(self):
self.assertTrue(self.server_observer.on_finish.called)
self.assertTrue(self.context_init_event_subscriber.called)

def test_b3_trace_headers(self):
self.test_app.get("/example", headers={
"X-B3-TraceId": "1234",
"X-B3-ParentSpanId": "2345",
"X-B3-SpanId": "3456",
"X-B3-Sampled": "1",
"X-B3-Flags": "1",
})

self.assertEqual(self.observer.on_server_span_created.call_count, 1)

context, server_span = self.observer.on_server_span_created.call_args[0]
self.assertEqual(server_span.trace_id, 1234)
self.assertEqual(server_span.parent_id, 2345)
self.assertEqual(server_span.id, 3456)
self.assertEqual(server_span.sampled, True)
self.assertEqual(server_span.flags, 1)

with self.assertRaises(NoAuthenticationError):
context.request_context.user.id

self.assertTrue(self.server_observer.on_start.called)
self.assertTrue(self.server_observer.on_finish.called)
self.assertTrue(self.context_init_event_subscriber.called)

def test_edge_request_headers(self):
self.test_app.get("/example", headers={
"X-Trace": "1234",
Expand Down
33 changes: 33 additions & 0 deletions tests/integration/thrift_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,39 @@ def test_with_headers(self):
self.assertEqual(self.server_observer.on_finish.call_count, 1)
self.assertEqual(self.server_observer.on_finish.call_args[0], (None,))

def test_b3_trace_headers(self):
client_memory_trans = TMemoryBuffer()
client_prot = THeaderProtocol(client_memory_trans)
client_header_trans = client_prot.trans
client_header_trans.set_header("B3-TraceId", "1234")
client_header_trans.set_header("B3-ParentSpanId", "2345")
client_header_trans.set_header("B3-SpanId", "3456")
client_header_trans.set_header("B3-Sampled", "1")
client_header_trans.set_header("B3-Flags", "1")
client = TestService.Client(client_prot)
try:
client.example_simple()
except TTransportException:
pass # we don't have a test response for the client
self.itrans._readBuffer = StringIO(client_memory_trans.getvalue())

self.processor.process(self.iprot, self.oprot, self.server_context)
self.assertEqual(self.observer.on_server_span_created.call_count, 1)

context, server_span = self.observer.on_server_span_created.call_args[0]
self.assertEqual(server_span.trace_id, 1234)
self.assertEqual(server_span.parent_id, 2345)
self.assertEqual(server_span.id, 3456)
self.assertTrue(server_span.sampled)
self.assertEqual(server_span.flags, 1)

with self.assertRaises(NoAuthenticationError):
context.request_context.user.id

self.assertEqual(self.server_observer.on_start.call_count, 1)
self.assertEqual(self.server_observer.on_finish.call_count, 1)
self.assertEqual(self.server_observer.on_finish.call_args[0], (None,))

def test_edge_request_headers(self):
client_memory_trans = TMemoryBuffer()
client_prot = THeaderProtocol(client_memory_trans)
Expand Down

0 comments on commit eb71fff

Please sign in to comment.