Skip to content

Commit

Permalink
Pika - add publish_hook and consume_hook (open-telemetry#763)
Browse files Browse the repository at this point in the history
  • Loading branch information
ItayGibel-helios authored and nicholasgribanov committed Oct 29, 2021
1 parent 9b77f1c commit afc2322
Show file tree
Hide file tree
Showing 6 changed files with 184 additions and 17 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- `opentelemetry-distro` uses the correct entrypoint name which was updated in the core release of 1.6.0 but the distro was not updated with it
([#755](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/755))

### Added
- `opentelemetry-instrumentation-pika` Add `publish_hook` and `consume_hook` callbacks passed as arguments to the instrument method
([#763](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/763))


## [1.6.1-0.25b1](https://github.com/open-telemetry/opentelemetry-python/releases/tag/v1.6.1-0.25b1) - 2021-10-18

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,21 @@
PikaInstrumentor.instrument_channel(channel, tracer_provider=tracer_provider)
* PikaInstrumentor also supports instrumenting with hooks that will be called when producing or consuming a message.
The hooks should be of type `Callable[[Span, bytes, BasicProperties], None]`
where the first parameter is the span, the second parameter is the message body
and the third parameter is the message properties
.. code-block:: python
def publish_hook(span: Span, body: bytes, properties: BasicProperties):
span.set_attribute("messaging.payload", body.decode())
def consume_hook(span: Span, body: bytes, properties: BasicProperties):
span.set_attribute("messaging.id", properties.message_id)
PikaInstrumentor.instrument_channel(channel, publish_hook=publish_hook, consume_hook=consume_hook)
API
---
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,16 @@ class PikaInstrumentor(BaseInstrumentor): # type: ignore
# pylint: disable=attribute-defined-outside-init
@staticmethod
def _instrument_blocking_channel_consumers(
channel: BlockingChannel, tracer: Tracer
channel: BlockingChannel,
tracer: Tracer,
consume_hook: utils.HookT = utils.dummy_callback,
) -> Any:
for consumer_tag, consumer_info in channel._consumer_infos.items():
decorated_callback = utils._decorate_callback(
consumer_info.on_message_callback, tracer, consumer_tag
consumer_info.on_message_callback,
tracer,
consumer_tag,
consume_hook,
)

setattr(
Expand All @@ -52,22 +57,28 @@ def _instrument_blocking_channel_consumers(

@staticmethod
def _instrument_basic_publish(
channel: BlockingChannel, tracer: Tracer
channel: BlockingChannel,
tracer: Tracer,
publish_hook: utils.HookT = utils.dummy_callback,
) -> None:
original_function = getattr(channel, "basic_publish")
decorated_function = utils._decorate_basic_publish(
original_function, channel, tracer
original_function, channel, tracer, publish_hook
)
setattr(decorated_function, "_original_function", original_function)
channel.__setattr__("basic_publish", decorated_function)
channel.basic_publish = decorated_function

@staticmethod
def _instrument_channel_functions(
channel: BlockingChannel, tracer: Tracer
channel: BlockingChannel,
tracer: Tracer,
publish_hook: utils.HookT = utils.dummy_callback,
) -> None:
if hasattr(channel, "basic_publish"):
PikaInstrumentor._instrument_basic_publish(channel, tracer)
PikaInstrumentor._instrument_basic_publish(
channel, tracer, publish_hook
)

@staticmethod
def _uninstrument_channel_functions(channel: BlockingChannel) -> None:
Expand All @@ -84,6 +95,8 @@ def _uninstrument_channel_functions(channel: BlockingChannel) -> None:
def instrument_channel(
channel: BlockingChannel,
tracer_provider: Optional[TracerProvider] = None,
publish_hook: utils.HookT = utils.dummy_callback,
consume_hook: utils.HookT = utils.dummy_callback,
) -> None:
if not hasattr(channel, "_is_instrumented_by_opentelemetry"):
channel._is_instrumented_by_opentelemetry = False
Expand All @@ -94,10 +107,12 @@ def instrument_channel(
return
tracer = trace.get_tracer(__name__, __version__, tracer_provider)
PikaInstrumentor._instrument_blocking_channel_consumers(
channel, tracer
channel, tracer, consume_hook
)
PikaInstrumentor._decorate_basic_consume(channel, tracer, consume_hook)
PikaInstrumentor._instrument_channel_functions(
channel, tracer, publish_hook
)
PikaInstrumentor._decorate_basic_consume(channel, tracer)
PikaInstrumentor._instrument_channel_functions(channel, tracer)

@staticmethod
def uninstrument_channel(channel: BlockingChannel) -> None:
Expand All @@ -118,33 +133,53 @@ def uninstrument_channel(channel: BlockingChannel) -> None:
PikaInstrumentor._uninstrument_channel_functions(channel)

def _decorate_channel_function(
self, tracer_provider: Optional[TracerProvider]
self,
tracer_provider: Optional[TracerProvider],
publish_hook: utils.HookT = utils.dummy_callback,
consume_hook: utils.HookT = utils.dummy_callback,
) -> None:
def wrapper(wrapped, instance, args, kwargs):
channel = wrapped(*args, **kwargs)
self.instrument_channel(channel, tracer_provider=tracer_provider)
self.instrument_channel(
channel,
tracer_provider=tracer_provider,
publish_hook=publish_hook,
consume_hook=consume_hook,
)
return channel

wrapt.wrap_function_wrapper(BlockingConnection, "channel", wrapper)

@staticmethod
def _decorate_basic_consume(
channel: BlockingChannel, tracer: Optional[Tracer]
channel: BlockingChannel,
tracer: Optional[Tracer],
consume_hook: utils.HookT = utils.dummy_callback,
) -> None:
def wrapper(wrapped, instance, args, kwargs):
return_value = wrapped(*args, **kwargs)

PikaInstrumentor._instrument_blocking_channel_consumers(
channel, tracer
channel, tracer, consume_hook
)
return return_value

wrapt.wrap_function_wrapper(channel, "basic_consume", wrapper)

def _instrument(self, **kwargs: Dict[str, Any]) -> None:
tracer_provider: TracerProvider = kwargs.get("tracer_provider", None)
publish_hook: utils.HookT = kwargs.get(
"publish_hook", utils.dummy_callback
)
consume_hook: utils.HookT = kwargs.get(
"consume_hook", utils.dummy_callback
)

self.__setattr__("__opentelemetry_tracer_provider", tracer_provider)
self._decorate_channel_function(tracer_provider)
self._decorate_channel_function(
tracer_provider,
publish_hook=publish_hook,
consume_hook=consume_hook,
)

def _uninstrument(self, **kwargs: Dict[str, Any]) -> None:
if hasattr(self, "__opentelemetry_tracer_provider"):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from logging import getLogger
from typing import Any, Callable, List, Optional

from pika.channel import Channel
Expand All @@ -13,6 +14,8 @@
from opentelemetry.trace import SpanKind, Tracer
from opentelemetry.trace.span import Span

_LOG = getLogger(__name__)


class _PikaGetter(Getter): # type: ignore
def get(self, carrier: CarrierT, key: str) -> Optional[List[str]]:
Expand All @@ -27,11 +30,18 @@ def keys(self, carrier: CarrierT) -> List[str]:

_pika_getter = _PikaGetter()

HookT = Callable[[Span, bytes, BasicProperties], None]


def dummy_callback(span: Span, body: bytes, properties: BasicProperties):
...


def _decorate_callback(
callback: Callable[[Channel, Basic.Deliver, BasicProperties, bytes], Any],
tracer: Tracer,
task_name: str,
consume_hook: HookT = dummy_callback,
):
def decorated_callback(
channel: Channel,
Expand Down Expand Up @@ -60,6 +70,10 @@ def decorated_callback(
)
try:
with trace.use_span(span, end_on_exit=True):
try:
consume_hook(span, body, properties)
except Exception as hook_exception: # pylint: disable=W0703
_LOG.exception(hook_exception)
retval = callback(channel, method, properties, body)
finally:
context.detach(token)
Expand All @@ -72,6 +86,7 @@ def _decorate_basic_publish(
original_function: Callable[[str, str, bytes, BasicProperties, bool], Any],
channel: Channel,
tracer: Tracer,
publish_hook: HookT = dummy_callback,
):
def decorated_function(
exchange: str,
Expand Down Expand Up @@ -100,6 +115,10 @@ def decorated_function(
with trace.use_span(span, end_on_exit=True):
if span.is_recording():
propagate.inject(properties.headers)
try:
publish_hook(span, body, properties)
except Exception as hook_exception: # pylint: disable=W0703
_LOG.exception(hook_exception)
retval = original_function(
exchange, routing_key, body, properties, mandatory
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from wrapt import BoundFunctionWrapper

from opentelemetry.instrumentation.pika import PikaInstrumentor
from opentelemetry.instrumentation.pika.utils import dummy_callback
from opentelemetry.trace import Tracer


Expand Down Expand Up @@ -72,7 +73,7 @@ def test_instrument_consumers(
) -> None:
tracer = mock.MagicMock(spec=Tracer)
expected_decoration_calls = [
mock.call(value.on_message_callback, tracer, key)
mock.call(value.on_message_callback, tracer, key, dummy_callback)
for key, value in self.channel._consumer_infos.items()
]
PikaInstrumentor._instrument_blocking_channel_consumers(
Expand All @@ -96,7 +97,7 @@ def test_instrument_basic_publish(
original_function = self.channel.basic_publish
PikaInstrumentor._instrument_basic_publish(self.channel, tracer)
decorate_basic_publish.assert_called_once_with(
original_function, self.channel, tracer
original_function, self.channel, tracer, dummy_callback
)
self.assertEqual(
self.channel.basic_publish, decorate_basic_publish.return_value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,52 @@ def test_decorate_callback(
)
self.assertEqual(retval, callback.return_value)

@mock.patch("opentelemetry.instrumentation.pika.utils._get_span")
@mock.patch("opentelemetry.propagate.extract")
@mock.patch("opentelemetry.trace.use_span")
def test_decorate_callback_with_hook(
self,
use_span: mock.MagicMock,
extract: mock.MagicMock,
get_span: mock.MagicMock,
) -> None:
callback = mock.MagicMock()
mock_task_name = "mock_task_name"
tracer = mock.MagicMock()
channel = mock.MagicMock(spec=Channel)
method = mock.MagicMock(spec=Basic.Deliver)
method.exchange = "test_exchange"
properties = mock.MagicMock()
mock_body = b"mock_body"
consume_hook = mock.MagicMock()

decorated_callback = utils._decorate_callback(
callback, tracer, mock_task_name, consume_hook
)
retval = decorated_callback(channel, method, properties, mock_body)
extract.assert_called_once_with(
properties.headers, getter=utils._pika_getter
)
get_span.assert_called_once_with(
tracer,
channel,
properties,
destination=method.exchange,
span_kind=SpanKind.CONSUMER,
task_name=mock_task_name,
operation=MessagingOperationValues.RECEIVE,
)
use_span.assert_called_once_with(
get_span.return_value, end_on_exit=True
)
consume_hook.assert_called_once_with(
get_span.return_value, mock_body, properties
)
callback.assert_called_once_with(
channel, method, properties, mock_body
)
self.assertEqual(retval, callback.return_value)

@mock.patch("opentelemetry.instrumentation.pika.utils._get_span")
@mock.patch("opentelemetry.propagate.inject")
@mock.patch("opentelemetry.trace.use_span")
Expand Down Expand Up @@ -310,3 +356,49 @@ def test_decorate_basic_publish_published_message_to_queue(
task_name="(temporary)",
operation=None,
)

@mock.patch("opentelemetry.instrumentation.pika.utils._get_span")
@mock.patch("opentelemetry.propagate.inject")
@mock.patch("opentelemetry.trace.use_span")
def test_decorate_basic_publish_with_hook(
self,
use_span: mock.MagicMock,
inject: mock.MagicMock,
get_span: mock.MagicMock,
) -> None:
callback = mock.MagicMock()
tracer = mock.MagicMock()
channel = mock.MagicMock(spec=Channel)
exchange_name = "test-exchange"
routing_key = "test-routing-key"
properties = mock.MagicMock()
mock_body = b"mock_body"
publish_hook = mock.MagicMock()

decorated_basic_publish = utils._decorate_basic_publish(
callback, channel, tracer, publish_hook
)
retval = decorated_basic_publish(
exchange_name, routing_key, mock_body, properties
)
get_span.assert_called_once_with(
tracer,
channel,
properties,
destination=exchange_name,
span_kind=SpanKind.PRODUCER,
task_name="(temporary)",
operation=None,
)
use_span.assert_called_once_with(
get_span.return_value, end_on_exit=True
)
get_span.return_value.is_recording.assert_called_once()
inject.assert_called_once_with(properties.headers)
publish_hook.assert_called_once_with(
get_span.return_value, mock_body, properties
)
callback.assert_called_once_with(
exchange_name, routing_key, mock_body, properties, False
)
self.assertEqual(retval, callback.return_value)

0 comments on commit afc2322

Please sign in to comment.