Skip to content

Commit

Permalink
[fix] test expiration header
Browse files Browse the repository at this point in the history
  • Loading branch information
mosquito committed May 17, 2017
1 parent 3c535cd commit d4eba38
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 9 deletions.
21 changes: 12 additions & 9 deletions aio_pika/message.py
@@ -1,4 +1,3 @@
import json
from datetime import datetime, timedelta
from enum import IntEnum, unique
from functools import singledispatch
Expand Down Expand Up @@ -26,7 +25,7 @@ class DeliveryMode(IntEnum):

@singledispatch
def convert_timestamp(value) -> Optional[int]:
raise ValueError('Invalid expiration type: %r' % type(value), value)
raise ValueError('Invalid timestamp type: %r' % type(value), value)


@convert_timestamp.register(datetime)
Expand All @@ -36,17 +35,21 @@ def _convert_datetime(value):


@convert_timestamp.register(int)
def _convert_int(value):
return value


@convert_timestamp.register(float)
def _convert_numbers(value):
return value
return int(value)


@convert_timestamp.register(timedelta)
def _convert_timedelta(value):
return int(value.total_seconds())


@convert_timestamp.register(None)
@convert_timestamp.register(type(None))
def _convert_none(_):
return None

Expand Down Expand Up @@ -97,7 +100,7 @@ def __init__(self, body: bytes, *, headers: dict = None, content_type: str = Non
self.priority = priority
self.correlation_id = self._as_bytes(correlation_id)
self.reply_to = reply_to
self.expiration = convert_timestamp(expiration) * 1000 if expiration else None
self.expiration = expiration
self.message_id = message_id
self.timestamp = convert_timestamp(timestamp)
self.type = type
Expand Down Expand Up @@ -149,7 +152,7 @@ def properties(self) -> BasicProperties:
priority=self.priority,
correlation_id=self.correlation_id,
reply_to=self.reply_to,
expiration=str(int(self.expiration)) if self.expiration else None,
expiration=str(convert_timestamp(self.expiration * 1000)) if self.expiration else None,
message_id=self.message_id,
timestamp=self.timestamp,
type=self.type,
Expand Down Expand Up @@ -236,7 +239,7 @@ def __init__(self, channel: Channel, envelope, properties, body, no_ack: bool =

expiration = None
if properties.expiration:
expiration = self._convert_timestamp(float(properties.expiration) / 1000.)
expiration = convert_timestamp(float(properties.expiration))

super().__init__(
body=body,
Expand All @@ -247,9 +250,9 @@ def __init__(self, channel: Channel, envelope, properties, body, no_ack: bool =
priority=properties.priority,
correlation_id=properties.correlation_id,
reply_to=properties.reply_to,
expiration=expiration / 1000 if expiration else None,
expiration=expiration / 1000. if expiration else None,
message_id=properties.message_id,
timestamp=self._convert_timestamp(float(properties.timestamp)) if properties.timestamp else None,
timestamp=convert_timestamp(float(properties.timestamp)) if properties.timestamp else None,
type=properties.type,
user_id=properties.user_id,
app_id=properties.app_id,
Expand Down
48 changes: 48 additions & 0 deletions tests/test_amqp.py
Expand Up @@ -240,6 +240,8 @@ def test_incoming_message_info(self):

body = bytes(shortuuid.uuid(), 'utf-8')

self.maxDiff = None

info = {
'headers': {"foo": "bar"},
'content_type': "application/json",
Expand Down Expand Up @@ -894,6 +896,52 @@ def bad_handler(message):

yield from wait((client.close(), client.closing), loop=self.loop)

@pytest.mark.asyncio
def test_expiration(self):
client = yield from connect(AMQP_URL, loop=self.loop)

channel = yield from client.channel() # type: aio_pika.Channel

dlx_queue = yield from channel.declare_queue(
self.get_random_name("test_dlx")
) # type: aio_pika.Queue

dlx_exchange = yield from channel.declare_exchange(
self.get_random_name("dlx"),
) # type: aio_pika.Exchange

yield from dlx_queue.bind(dlx_exchange, routing_key=dlx_queue.name)

queue = yield from channel.declare_queue(
self.get_random_name("test_expiration"),
arguments={
"x-message-ttl": 10000,
"x-dead-letter-exchange": dlx_exchange.name,
"x-dead-letter-routing-key": dlx_queue.name,
}
) # type: aio_pika.Queue

body = bytes(shortuuid.uuid(), 'utf-8')

yield from channel.default_exchange.publish(
Message(
body,
content_type='text/plain',
headers={'foo': 'bar'},
expiration=0.5
),
queue.name
)

yield from asyncio.sleep(1, loop=self.loop)

message = yield from dlx_queue.get(timeout=1, no_ack=True)

self.assertEqual(message.body, body)
self.assertEqual(message.headers['x-death'][0]['original-expiration'], '500')

yield from wait((client.close(), client.closing), loop=self.loop)


class MessageTestCase(unittest.TestCase):
def test_message_copy(self):
Expand Down

0 comments on commit d4eba38

Please sign in to comment.