Skip to content

Commit

Permalink
Merge adc5229 into ec210b6
Browse files Browse the repository at this point in the history
  • Loading branch information
mosquito committed Nov 12, 2017
2 parents ec210b6 + adc5229 commit 269e80f
Show file tree
Hide file tree
Showing 8 changed files with 62 additions and 19 deletions.
27 changes: 23 additions & 4 deletions aio_pika/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,12 @@ class Channel(BaseChannel):

__slots__ = ('_connection', '__closing', '_confirmations', '_delivery_tag',
'loop', '_futures', '_channel', '_on_return_callbacks',
'default_exchange', '_write_lock', '_channel_number', '_publisher_confirms')
'default_exchange', '_write_lock', '_channel_number',
'_publisher_confirms', '_on_return_raises')

def __init__(self, connection, loop: asyncio.AbstractEventLoop,
future_store: FutureStore, channel_number: int=None, publisher_confirms: bool=True):
future_store: FutureStore, channel_number: int=None,
publisher_confirms: bool=True, on_return_raises=False):
"""
:param connection: :class:`aio_pika.adapter.AsyncioConnection` instance
Expand All @@ -50,6 +52,11 @@ def __init__(self, connection, loop: asyncio.AbstractEventLoop,
self._channel_number = channel_number
self._publisher_confirms = publisher_confirms

if not publisher_confirms and on_return_raises:
raise RuntimeError('on_return_raises must be uses with publisher confirms')

self._on_return_raises = on_return_raises

self.default_exchange = self.EXCHANGE_CLASS(
self._channel,
self._publish,
Expand Down Expand Up @@ -133,6 +140,10 @@ def _create_channel(self, timeout=None):
channel = yield from future # type: pika.channel.Channel
if self._publisher_confirms:
channel.confirm_delivery(self._on_delivery_confirmation)

if self._on_return_raises:
channel.add_on_return_callback(self._on_return_delivery)

channel.add_on_close_callback(self._on_channel_close)
channel.add_on_return_callback(self._on_return)

Expand All @@ -146,6 +157,10 @@ def initialize(self, timeout=None) -> None:

self._channel = yield from self._create_channel(timeout)

def _on_return_delivery(self, channel, method_frame, properties, body):
f = self._confirmations.pop(int(properties.headers.get('delivery-tag')))
f.set_exception(exceptions.UnroutableError([body]))

def _on_delivery_confirmation(self, method_frame):
future = self._confirmations.pop(method_frame.method.delivery_tag, None)

Expand Down Expand Up @@ -192,13 +207,18 @@ def declare_exchange(self, name: str, type: ExchangeType = ExchangeType.DIRECT,

@BaseChannel._ensure_channel_is_open
@asyncio.coroutine
def _publish(self, queue_name, routing_key, body, properties, mandatory, immediate):
def _publish(self, queue_name, routing_key, body, properties: pika.BasicProperties, mandatory, immediate):
with (yield from self._write_lock):
while self._connection.is_closed:
log.debug("Can't publish message because connection is inactive")
yield from asyncio.sleep(1, loop=self.loop)

f = self._create_future()
self._delivery_tag += 1

if self._on_return_raises:
properties.headers = properties.headers or {}
properties.headers['delivery-tag'] = str(self._delivery_tag)

try:
self._channel.basic_publish(queue_name, routing_key, body, properties, mandatory, immediate)
Expand All @@ -207,7 +227,6 @@ def _publish(self, queue_name, routing_key, body, properties, mandatory, immedia
self._on_channel_close(self._channel, -1, exc)
self._connection.close(reply_code=500, reply_text="Incorrect state")
else:
self._delivery_tag += 1
if self._publisher_confirms:
self._confirmations[self._delivery_tag] = f
else:
Expand Down
7 changes: 3 additions & 4 deletions aio_pika/common.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import asyncio
from contextlib import suppress

import pika.channel
import pika.exceptions
from logging import getLogger
from functools import wraps
from enum import Enum, unique
from .tools import create_future
from . import exceptions


log = getLogger(__name__)
Expand Down Expand Up @@ -111,8 +110,8 @@ def _create_future(self, timeout=None):
def _ensure_channel_is_open(func):
@wraps(func)
def wrap(self, *args, **kwargs):
if self._closing.done():
raise pika.exceptions.ChannelClosed
if self.is_closed:
raise exceptions.ChannelClosed

return func(self, *args, **kwargs)

Expand Down
17 changes: 11 additions & 6 deletions aio_pika/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,8 @@ def connect(self) -> AsyncioConnection:

@_ensure_connection
@asyncio.coroutine
def channel(self, channel_number: int=None, publisher_confirms: bool=True) -> Generator[Any, None, Channel]:
def channel(self, channel_number: int=None, publisher_confirms: bool=True,
on_return_raises=False) -> Generator[Any, None, Channel]:
""" Coroutine which returns new instance of :class:`Channel`.
Example:
Expand All @@ -237,17 +238,21 @@ async def main(loop):
await channel_no_confirms.close()
:param channel_number: specify the channel number explicit
:param publisher_confirms: if `True` the :method:`aio_pika.Exchange.publish` method will be return
:class:`bool` after publish is complete. Otherwise the :method:`aio_pika.Exchange.publish` method will be
return :class:`None`
:param publisher_confirms:
if `True` the :method:`aio_pika.Exchange.publish` method will be return
:class:`bool` after publish is complete. Otherwise the
:method:`aio_pika.Exchange.publish` method will be return :class:`None`
:param on_return_raises:
raise an :class:`aio_pika.exceptions.UnroutableError`
when mandatory message will be returned
"""
with (yield from self.__write_lock):
log.debug("Creating AMQP channel for conneciton: %r", self)

channel = self.CHANNEL_CLASS(self, self.loop, self.future_store,
channel_number=channel_number,
publisher_confirms=publisher_confirms)
publisher_confirms=publisher_confirms,
on_return_raises=on_return_raises)
yield from channel.initialize()

log.debug("Channel created: %r", channel)
Expand Down
4 changes: 3 additions & 1 deletion aio_pika/robust_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ class RobustChannel(Channel):
EXCHANGE_CLASS = RobustExchange

def __init__(self, connection, loop: asyncio.AbstractEventLoop,
future_store: FutureStore, channel_number: int=None, publisher_confirms: bool=True):
future_store: FutureStore, channel_number: int=None,
publisher_confirms: bool=True, on_return_raises=False):
"""
:param connection: :class:`aio_pika.adapter.AsyncioConnection` instance
Expand All @@ -40,6 +41,7 @@ def __init__(self, connection, loop: asyncio.AbstractEventLoop,
connection=connection,
channel_number=channel_number,
publisher_confirms=publisher_confirms,
on_return_raises=on_return_raises,
)

self._closed = False
Expand Down
3 changes: 1 addition & 2 deletions aio_pika/robust_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
from logging import getLogger
from typing import Callable, Generator, Any

from pika.exceptions import ProbableAuthenticationError

from .adapter import AsyncioConnection
from .exceptions import ProbableAuthenticationError
from .connection import Connection, connect
from .robust_channel import RobustChannel

Expand Down
2 changes: 1 addition & 1 deletion aio_pika/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

team_email = 'me@mosquito.su'

version_info = (1, 5, 1)
version_info = (1, 6, 0)

__author__ = ", ".join("{} <{}>".format(*info) for info in author_info)
__version__ = ".".join(map(str, version_info))
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
install_requires=requires,
extras_require={
'develop': [
'asynctest',
'asynctest<0.11',
'coverage!=4.3',
'coveralls',
'pylama',
Expand Down
19 changes: 19 additions & 0 deletions tests/test_amqp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1177,6 +1177,25 @@ def test_message_nack(self):
yield from queue.delete()
yield from wait((client.close(), client.closing), loop=self.loop)

@pytest.mark.asyncio
def test_on_return_raises(self):
client = yield from self.create_connection()
queue_name = self.get_random_name("test_on_return_raises")
body = uuid.uuid4().bytes

with self.assertRaises(RuntimeError):
yield from client.channel(publisher_confirms=False, on_return_raises=True)

channel = yield from client.channel(publisher_confirms=True, on_return_raises=True)

for _ in range(100):
with self.assertRaises(aio_pika.exceptions.UnroutableError):
yield from channel.default_exchange.publish(
Message(body=body), routing_key=queue_name,
)

yield from client.close()

@asyncio.coroutine
def test_transaction_when_publisher_confirms_error(self):
channel = yield from self.create_channel(publisher_confirms=True)
Expand Down

0 comments on commit 269e80f

Please sign in to comment.