Skip to content

Commit

Permalink
Bugfix: Pika basicConsume context propegation
Browse files Browse the repository at this point in the history
Fixing the context propegation to consumer callback.
Bug was fix by attaching and detaching the context before executing the user callback.
Hook location was changed to hook the user callback and not the async enqeueuing of messages.
  • Loading branch information
oxeye-yuval committed Oct 21, 2021
1 parent 3049b4b commit 553a42a
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import wrapt
from pika.adapters import BlockingConnection
from pika.channel import Channel
from pika.adapters.blocking_connection import BlockingChannel

from opentelemetry import trace
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
Expand All @@ -35,18 +35,25 @@
class PikaInstrumentor(BaseInstrumentor): # type: ignore
# pylint: disable=attribute-defined-outside-init
@staticmethod
def _instrument_consumers(
consumers_dict: Dict[str, Callable[..., Any]], tracer: Tracer
def _instrument_blocking_channel_consumers(
channel: BlockingChannel, tracer: Tracer
) -> Any:
for key, callback in consumers_dict.items():
for consumer_tag, consumer_info in channel._consumer_infos.items():
decorated_callback = utils._decorate_callback(
callback, tracer, key
consumer_info.on_message_callback, tracer, consumer_tag
)
setattr(decorated_callback, "_original_callback", callback)
consumers_dict[key] = decorated_callback

setattr(
decorated_callback,
"_original_callback",
consumer_info.on_message_callback,
)
consumer_info.on_message_callback = decorated_callback

@staticmethod
def _instrument_basic_publish(channel: Channel, tracer: Tracer) -> None:
def _instrument_basic_publish(
channel: BlockingChannel, tracer: Tracer
) -> None:
original_function = getattr(channel, "basic_publish")
decorated_function = utils._decorate_basic_publish(
original_function, channel, tracer
Expand All @@ -57,13 +64,13 @@ def _instrument_basic_publish(channel: Channel, tracer: Tracer) -> None:

@staticmethod
def _instrument_channel_functions(
channel: Channel, tracer: Tracer
channel: BlockingChannel, tracer: Tracer
) -> None:
if hasattr(channel, "basic_publish"):
PikaInstrumentor._instrument_basic_publish(channel, tracer)

@staticmethod
def _uninstrument_channel_functions(channel: Channel) -> None:
def _uninstrument_channel_functions(channel: BlockingChannel) -> None:
for function_name in _FUNCTIONS_TO_UNINSTRUMENT:
if not hasattr(channel, function_name):
continue
Expand All @@ -73,8 +80,10 @@ def _uninstrument_channel_functions(channel: Channel) -> None:
unwrap(channel, "basic_consume")

@staticmethod
# Make sure that the spans are created inside hash them set as parent and not as brothers
def instrument_channel(
channel: Channel, tracer_provider: Optional[TracerProvider] = None,
channel: BlockingChannel,
tracer_provider: Optional[TracerProvider] = None,
) -> None:
if not hasattr(channel, "_is_instrumented_by_opentelemetry"):
channel._is_instrumented_by_opentelemetry = False
Expand All @@ -84,18 +93,14 @@ def instrument_channel(
)
return
tracer = trace.get_tracer(__name__, __version__, tracer_provider)
if not hasattr(channel, "_impl"):
_LOG.error("Could not find implementation for provided channel!")
return
if channel._impl._consumers:
PikaInstrumentor._instrument_consumers(
channel._impl._consumers, tracer
)
PikaInstrumentor._instrument_blocking_channel_consumers(
channel, tracer
)
PikaInstrumentor._decorate_basic_consume(channel, tracer)
PikaInstrumentor._instrument_channel_functions(channel, tracer)

@staticmethod
def uninstrument_channel(channel: Channel) -> None:
def uninstrument_channel(channel: BlockingChannel) -> None:
if (
not hasattr(channel, "_is_instrumented_by_opentelemetry")
or not channel._is_instrumented_by_opentelemetry
Expand All @@ -104,12 +109,12 @@ def uninstrument_channel(channel: Channel) -> None:
"Attempting to uninstrument Pika channel while already uninstrumented!"
)
return
if not hasattr(channel, "_impl"):
_LOG.error("Could not find implementation for provided channel!")
return
for key, callback in channel._impl._consumers.items():
if hasattr(callback, "_original_callback"):
channel._impl._consumers[key] = callback._original_callback

for consumers_tag, client_info in channel._consumer_infos.items():
if hasattr(client_info.on_message_callback, "_original_callback"):
channel._consumer_infos[
consumers_tag
] = client_info.on_message_callback._original_callback
PikaInstrumentor._uninstrument_channel_functions(channel)

def _decorate_channel_function(
Expand All @@ -123,28 +128,15 @@ def wrapper(wrapped, instance, args, kwargs):
wrapt.wrap_function_wrapper(BlockingConnection, "channel", wrapper)

@staticmethod
def _decorate_basic_consume(channel, tracer: Optional[Tracer]) -> None:
def _decorate_basic_consume(
channel: BlockingChannel, tracer: Optional[Tracer]
) -> None:
def wrapper(wrapped, instance, args, kwargs):
if not hasattr(channel, "_impl"):
_LOG.error(
"Could not find implementation for provided channel!"
)
return wrapped(*args, **kwargs)
current_keys = set(channel._impl._consumers.keys())
return_value = wrapped(*args, **kwargs)
new_key_list = list(
set(channel._impl._consumers.keys()) - current_keys
)
if not new_key_list:
_LOG.error("Could not find added callback")
return return_value
new_key = new_key_list[0]
callback = channel._impl._consumers[new_key]
decorated_callback = utils._decorate_callback(
callback, tracer, new_key

PikaInstrumentor._instrument_blocking_channel_consumers(
channel, tracer
)
setattr(decorated_callback, "_original_callback", callback)
channel._impl._consumers[new_key] = decorated_callback
return return_value

wrapt.wrap_function_wrapper(channel, "basic_consume", wrapper)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,18 @@ def decorated_callback(
ctx = propagate.extract(properties.headers, getter=_pika_getter)
if not ctx:
ctx = context.get_current()
token = context.attach(ctx)
span = _get_span(
tracer,
channel,
properties,
span_kind=SpanKind.CONSUMER,
task_name=task_name,
ctx=ctx,
operation=MessagingOperationValues.RECEIVE,
)
with trace.use_span(span, end_on_exit=True):
retval = callback(channel, method, properties, body)
context.detach(token)
return retval

return decorated_callback
Expand All @@ -78,14 +79,12 @@ def decorated_function(
properties = BasicProperties(headers={})
if properties.headers is None:
properties.headers = {}
ctx = context.get_current()
span = _get_span(
tracer,
channel,
properties,
span_kind=SpanKind.PRODUCER,
task_name="(temporary)",
ctx=ctx,
operation=None,
)
if not span:
Expand All @@ -109,7 +108,6 @@ def _get_span(
properties: BasicProperties,
task_name: str,
span_kind: SpanKind,
ctx: context.Context,
operation: Optional[MessagingOperationValues] = None,
) -> Optional[Span]:
if context.get_value("suppress_instrumentation") or context.get_value(
Expand All @@ -118,9 +116,7 @@ def _get_span(
return None
task_name = properties.type if properties.type else task_name
span = tracer.start_span(
context=ctx,
name=_generate_span_name(task_name, operation),
kind=span_kind,
name=_generate_span_name("pika", operation), kind=span_kind,
)
if span.is_recording():
_enrich_span(span, channel, properties, task_name, operation)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@
class TestPika(TestCase):
def setUp(self) -> None:
self.channel = mock.MagicMock(spec=Channel)
self.channel._impl = mock.MagicMock(spec=BaseConnection)
consumer_info = mock.MagicMock()
consumer_info.on_message_callback = mock.MagicMock()
self.channel._consumer_infos = {"consumer-tag": consumer_info}
self.mock_callback = mock.MagicMock()
self.channel._impl._consumers = {"mock_key": self.mock_callback}

def test_instrument_api(self) -> None:
instrumentation = PikaInstrumentor()
Expand All @@ -49,19 +50,19 @@ def test_instrument_api(self) -> None:
"opentelemetry.instrumentation.pika.PikaInstrumentor._decorate_basic_consume"
)
@mock.patch(
"opentelemetry.instrumentation.pika.PikaInstrumentor._instrument_consumers"
"opentelemetry.instrumentation.pika.PikaInstrumentor._instrument_blocking_channel_consumers"
)
def test_instrument(
self,
instrument_consumers: mock.MagicMock,
instrument_blocking_channel_consumers: mock.MagicMock,
instrument_basic_consume: mock.MagicMock,
instrument_channel_functions: mock.MagicMock,
):
PikaInstrumentor.instrument_channel(channel=self.channel)
assert hasattr(
self.channel, "_is_instrumented_by_opentelemetry"
), "channel is not marked as instrumented!"
instrument_consumers.assert_called_once()
instrument_blocking_channel_consumers.assert_called_once()
instrument_basic_consume.assert_called_once()
instrument_channel_functions.assert_called_once()

Expand All @@ -71,18 +72,18 @@ def test_instrument_consumers(
) -> None:
tracer = mock.MagicMock(spec=Tracer)
expected_decoration_calls = [
mock.call(value, tracer, key)
for key, value in self.channel._impl._consumers.items()
mock.call(value.on_message_callback, tracer, key)
for key, value in self.channel._consumer_infos.items()
]
PikaInstrumentor._instrument_consumers(
self.channel._impl._consumers, tracer
PikaInstrumentor._instrument_blocking_channel_consumers(
self.channel, tracer
)
decorate_callback.assert_has_calls(
calls=expected_decoration_calls, any_order=True
)
assert all(
hasattr(callback, "_original_callback")
for callback in self.channel._impl._consumers.values()
for callback in self.channel._consumer_infos.values()
)

@mock.patch(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,10 @@ def test_get_span(
task_name = "test.test"
span_kind = mock.MagicMock(spec=SpanKind)
get_value.return_value = None
ctx = mock.MagicMock()
_ = utils._get_span(
tracer, channel, properties, task_name, span_kind, ctx
)
_ = utils._get_span(tracer, channel, properties, task_name, span_kind)
generate_span_name.assert_called_once()
tracer.start_span.assert_called_once_with(
context=ctx, name=generate_span_name.return_value, kind=span_kind
name=generate_span_name.return_value, kind=span_kind
)
enrich_span.assert_called_once()

Expand Down Expand Up @@ -200,7 +197,6 @@ def test_decorate_callback(
properties,
span_kind=SpanKind.CONSUMER,
task_name=mock_task_name,
ctx=extract.return_value,
operation=MessagingOperationValues.RECEIVE,
)
use_span.assert_called_once_with(
Expand All @@ -213,12 +209,10 @@ def test_decorate_callback(

@mock.patch("opentelemetry.instrumentation.pika.utils._get_span")
@mock.patch("opentelemetry.propagate.inject")
@mock.patch("opentelemetry.context.get_current")
@mock.patch("opentelemetry.trace.use_span")
def test_decorate_basic_publish(
self,
use_span: mock.MagicMock,
get_current: mock.MagicMock,
inject: mock.MagicMock,
get_span: mock.MagicMock,
) -> None:
Expand All @@ -234,14 +228,12 @@ def test_decorate_basic_publish(
retval = decorated_basic_publish(
channel, method, mock_body, properties
)
get_current.assert_called_once()
get_span.assert_called_once_with(
tracer,
channel,
properties,
span_kind=SpanKind.PRODUCER,
task_name="(temporary)",
ctx=get_current.return_value,
operation=None,
)
use_span.assert_called_once_with(
Expand All @@ -256,14 +248,12 @@ def test_decorate_basic_publish(

@mock.patch("opentelemetry.instrumentation.pika.utils._get_span")
@mock.patch("opentelemetry.propagate.inject")
@mock.patch("opentelemetry.context.get_current")
@mock.patch("opentelemetry.trace.use_span")
@mock.patch("pika.spec.BasicProperties.__new__")
def test_decorate_basic_publish_no_properties(
self,
basic_properties: mock.MagicMock,
use_span: mock.MagicMock,
get_current: mock.MagicMock,
inject: mock.MagicMock,
get_span: mock.MagicMock,
) -> None:
Expand All @@ -277,7 +267,6 @@ def test_decorate_basic_publish_no_properties(
)
retval = decorated_basic_publish(channel, method, body=mock_body)
basic_properties.assert_called_once_with(BasicProperties, headers={})
get_current.assert_called_once()
use_span.assert_called_once_with(
get_span.return_value, end_on_exit=True
)
Expand Down

0 comments on commit 553a42a

Please sign in to comment.