Skip to content

Commit

Permalink
feat: Add an option to bind the queue to custom routing keys
Browse files Browse the repository at this point in the history
  • Loading branch information
jpmckinney committed Jan 17, 2022
1 parent b6ac36a commit 8c6826b
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 17 deletions.
5 changes: 5 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ Changelog
0.0.6 (Unreleased)
------------------

Added
~~~~~

- :meth:`yapw.clients.Publisher.declare_queue` and :meth:`yapw.clients.Consumer.consume`: Rename the ``routing_key`` argument to ``queue``, and add a ``routing_keys`` optional argument.

Changed
~~~~~~~

Expand Down
26 changes: 24 additions & 2 deletions tests/clients/test_publisher.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from unittest.mock import patch
from unittest.mock import call, patch

import pika
import pytest
Expand Down Expand Up @@ -60,7 +60,29 @@ def test_declare_queue(connection, client_class, durable):
client.declare_queue("q")

client.channel.queue_declare.assert_called_once_with(queue="exch_q", durable=durable)
client.channel.queue_bind.assert_called_once_with(exchange="exch", queue="exch_q", routing_key="exch_q")
assert client.channel.queue_bind.call_count == 1
client.channel.queue_bind.assert_has_calls(
[
call(exchange="exch", queue="exch_q", routing_key="exch_q"),
]
)


@pytest.mark.parametrize("client_class,durable", [(DurableClient, True), (TransientClient, False)])
@patch("pika.BlockingConnection")
def test_declare_queue_routing_keys(connection, client_class, durable):
client = client_class(exchange="exch")

client.declare_queue("q", ["r", "k"])

client.channel.queue_declare.assert_called_once_with(queue="exch_q", durable=durable)
assert client.channel.queue_bind.call_count == 2
client.channel.queue_bind.assert_has_calls(
[
call(exchange="exch", queue="exch_q", routing_key="exch_r"),
call(exchange="exch", queue="exch_q", routing_key="exch_k"),
]
)


@pytest.mark.parametrize("client_class,delivery_mode", [(DurableClient, 2), (TransientClient, 1)])
Expand Down
32 changes: 29 additions & 3 deletions tests/clients/test_threaded.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def nack_warner(state, channel, method, properties, body):


def writer(state, channel, method, properties, body):
publish(state, channel, {"message": "value"}, "r")
publish(state, channel, {"message": "value"}, "n")
ack(state, channel, method.delivery_tag)


Expand Down Expand Up @@ -232,7 +232,7 @@ def test_publish(message, caplog):
("DEBUG", f"Received message {encode(message)} with routing key yapw_test_q and delivery tag 1"),
(
"DEBUG",
"Published message {'message': 'value'} on channel 1 to exchange yapw_test with routing key yapw_test_r",
"Published message {'message': 'value'} on channel 1 to exchange yapw_test with routing key yapw_test_n",
),
("DEBUG", "Ack'd message on channel 1 with delivery tag 1"),
("INFO", "Received SIGINT, shutting down gracefully"),
Expand All @@ -242,7 +242,7 @@ def test_publish(message, caplog):
def test_consume_declares_queue(caplog):
declarer = get_client()
declarer.connection.call_later(DELAY, functools.partial(kill, signal.SIGINT))
declarer.consume(nack_warner, "q")
declarer.consume(raiser, "q")

publisher = get_client()
publisher.publish({"message": "value"}, "q")
Expand All @@ -259,3 +259,29 @@ def test_consume_declares_queue(caplog):

assert len(caplog.records) > 1
assert all(r.levelname == "WARNING" and r.message == "{'message': 'value'}" for r in caplog.records)


def test_consume_declares_queue_routing_keys(caplog):
declarer = get_client()
declarer.connection.call_later(DELAY, functools.partial(kill, signal.SIGINT))
declarer.consume(raiser, "q", ["r", "k"])

publisher = get_client()
publisher.publish({"message": "r"}, "r")
publisher.publish({"message": "k"}, "k")

consumer = get_client()
consumer.connection.call_later(DELAY, functools.partial(kill, signal.SIGINT))
consumer.consume(ack_warner, "q", ["r", "k"])

publisher.channel.queue_purge("yapw_test_q")
publisher.close()

assert consumer.channel.is_closed
assert consumer.connection.is_closed

assert len(caplog.records) == 2
assert [(r.levelname, r.message) for r in caplog.records] == [
("WARNING", "{'message': 'r'}"),
("WARNING", "{'message': 'k'}"),
]
38 changes: 26 additions & 12 deletions yapw/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,16 +210,23 @@ def __init__(
if self.exchange:
self.channel.exchange_declare(exchange=self.exchange, exchange_type=exchange_type, durable=self.durable)

def declare_queue(self, routing_key: str) -> None:
def declare_queue(self, queue: str, routing_keys: Optional[List[str]] = None) -> None:
"""
Declare a queue named after the routing key, and bind it to the exchange with the routing key.
Declare a queue, and bind it to the exchange with the routing keys. If no routing keys are provided, the queue
is bound to the exchange using its name as the routing key.
:param routing_key: the routing key
:param queue: the queue's name
:param routing_keys: the queue's routing keys
"""
formatted = self.format_routing_key(routing_key)
if not routing_keys:
routing_keys = [queue]

formatted = self.format_routing_key(queue)
self.channel.queue_declare(queue=formatted, durable=self.durable)
self.channel.queue_bind(exchange=self.exchange, queue=formatted, routing_key=formatted)

for routing_key in routing_keys:
routing_key = self.format_routing_key(routing_key)
self.channel.queue_bind(exchange=self.exchange, queue=formatted, routing_key=routing_key)

def publish(self, message: Any, routing_key: str) -> None:
"""
Expand Down Expand Up @@ -260,7 +267,7 @@ class Threaded:

# Attributes that this mixin expects from base classes.
format_routing_key: Callable[["Threaded", str], str]
declare_queue: Callable[["Threaded", str], None]
declare_queue: Callable[["Threaded", str, Optional[List[str]]], None]
connection: pika.BlockingConnection
channel: pika.adapters.blocking_connection.BlockingChannel

Expand All @@ -281,21 +288,28 @@ def __init__(self, decode: Decode = default_decode, **kwargs: Any):

install_signal_handlers(self._on_shutdown)

def consume(self, callback: ConsumerCallback, routing_key: str, decorator: Decorator = halt) -> None:
def consume(
self,
callback: ConsumerCallback,
queue: str,
routing_keys: Optional[List[str]] = None,
decorator: Decorator = halt,
) -> None:
"""
Declare a queue named after and bound by the routing key, and start consuming messages from that queue.
Declare a queue, bind it to the exchange with the routing keys, and start consuming messages from that queue.
If no routing keys are provided, the queue is bound to the exchange using its name as the routing key.
The consumer callback must be a function that accepts ``(state, channel, method, properties, body)`` arguments,
all but the first of which are the same as Pika's ``basic_consume``. The ``state`` argument is needed to pass
attributes to :mod:`yapw.methods.blocking` functions.
:param callback: the consumer callback
:param routing_key: the routing key
:param queue: the queue's name
:param routing_keys: the queue's routing keys
:param decorator: the decorator of the consumer callback
"""
formatted = self.format_routing_key(routing_key)

self.declare_queue(routing_key)
self.declare_queue(queue, routing_keys)
formatted = self.format_routing_key(queue)

# Don't pass `self` to the callback, to prevent use of unsafe attributes and mutation of safe attributes.
klass = namedtuple("State", self.__getsafe__) # type: ignore # https://github.com/python/mypy/issues/848
Expand Down

0 comments on commit 8c6826b

Please sign in to comment.