Skip to content

Commit

Permalink
Multi-Callback Support [#48] (#49)
Browse files Browse the repository at this point in the history
* Multi-Callback Support [#48]
* Added sanity check to api connection test and fixing broken test
  • Loading branch information
eandersson committed Jan 5, 2018
1 parent 8c32060 commit 9efeaec
Show file tree
Hide file tree
Showing 25 changed files with 126 additions and 153 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
Changelog
=========

Version 2.4.0
-------------
- basic.consume now allows for multiple callbacks [#48].

Version 2.3.0
-------------
- Added delivery_tag property to message.
Expand Down
4 changes: 4 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ Additional documentation is available on `amqpstorm.io <https://www.amqpstorm.io
Changelog
=========

Version 2.4.0
-------------
- basic.consume now allows for multiple callbacks [#48].

Version 2.3.0
-------------
- Added delivery_tag property to message.
Expand Down
2 changes: 1 addition & 1 deletion amqpstorm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""AMQPStorm."""
__version__ = '2.3.0' # noqa
__version__ = '2.4.0' # noqa
__author__ = 'eandersson' # noqa

import logging
Expand Down
5 changes: 3 additions & 2 deletions amqpstorm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,12 @@ def consume(self, callback=None, queue='', consumer_tag='',
raise AMQPInvalidArgument('no_local should be a boolean')
elif arguments is not None and not isinstance(arguments, dict):
raise AMQPInvalidArgument('arguments should be a dict or None')
self._channel.consumer_callback = callback
consume_rpc_result = self._consume_rpc_request(arguments, consumer_tag,
exclusive, no_ack,
no_local, queue)
return self._consume_add_and_get_tag(consume_rpc_result)
tag = self._consume_add_and_get_tag(consume_rpc_result)
self._channel._consumer_callbacks[tag] = callback
return tag

def cancel(self, consumer_tag=''):
"""Cancel a queue consumer.
Expand Down
22 changes: 12 additions & 10 deletions amqpstorm/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@
class Channel(BaseChannel):
"""RabbitMQ Channel."""
__slots__ = [
'consumer_callback', 'rpc', '_basic', '_confirming_deliveries',
'_consumer_callbacks', 'rpc', '_basic', '_confirming_deliveries',
'_connection', '_exchange', '_inbound', '_queue', '_tx'
]

def __init__(self, channel_id, connection, rpc_timeout):
super(Channel, self).__init__(channel_id)
self.consumer_callback = None
self.rpc = Rpc(self, timeout=rpc_timeout)
self._consumer_callbacks = {}
self._confirming_deliveries = False
self._connection = connection
self._inbound = []
Expand Down Expand Up @@ -259,18 +259,17 @@ def process_data_events(self, to_tuple=False, auto_decode=True):
:return:
"""
if not self.consumer_callback:
raise AMQPChannelError('no consumer_callback defined')
if not self._consumer_callbacks:
raise AMQPChannelError('no consumer callback defined')
for message in self.build_inbound_messages(break_on_empty=True,
to_tuple=to_tuple,
auto_decode=auto_decode):
consumer_tag = message._method.get('consumer_tag')
if to_tuple:
# noinspection PyCallingNonCallable
self.consumer_callback(*message)
self._consumer_callbacks[consumer_tag](*message.to_tuple())
continue
# noinspection PyCallingNonCallable
self.consumer_callback(message)
sleep(IDLE_WAIT)
self._consumer_callbacks[consumer_tag](message)

def rpc_request(self, frame_out, adapter=None):
"""Perform a RPC Request.
Expand All @@ -297,10 +296,13 @@ def start_consuming(self, to_tuple=False, auto_decode=True):
:return:
"""
while not self.is_closed:
self.process_data_events(to_tuple=to_tuple,
auto_decode=auto_decode)
self.process_data_events(
to_tuple=to_tuple,
auto_decode=auto_decode
)
if not self.consumer_tags:
break
sleep(IDLE_WAIT)

def stop_consuming(self):
"""Stop consuming messages.
Expand Down
1 change: 0 additions & 1 deletion amqpstorm/management/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ def close(self, connection, reason='Closed via management api'):
:rtype: None
"""

close_payload = json.dumps({
'name': connection,
'reason': reason
Expand Down
52 changes: 36 additions & 16 deletions amqpstorm/tests/functional/generic_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ def test_functional_generator_consume(self):

# Store and inbound messages.
inbound_messages = []

for message in self.channel.build_inbound_messages(
break_on_empty=True):
self.assertIsInstance(message, Message)
Expand All @@ -277,29 +278,44 @@ def test_functional_consume_and_redeliver(self):
self.channel.basic.publish(body=self.message,
routing_key=self.queue_name)

def on_message(message):
def on_message_first(message):
self.channel.stop_consuming()
message.reject()

self.channel.basic.consume(callback=on_message,
self.channel.basic.consume(callback=on_message_first,
queue=self.queue_name,
no_ack=False)
self.channel.process_data_events()

# Sleep for 0.01s to make sure RabbitMQ has time to catch up.
time.sleep(0.01)

# Store and inbound messages.
inbound_messages = []

def on_message(message):
# Close current channel and open a new one.
self.channel.close()

# Sleep for 0.1s to make sure RabbitMQ has time to catch up.
time.sleep(0.1)

channel = self.connection.channel()

def on_message_second(message):
inbound_messages.append(message)
self.assertEqual(message.body, self.message)
message.ack()

self.channel.basic.consume(callback=on_message,
queue=self.queue_name,
no_ack=False)
self.channel.process_data_events()
channel.basic.consume(callback=on_message_second,
queue=self.queue_name,
no_ack=True)
channel.process_data_events()

# Sleep for 0.1s to make sure RabbitMQ has time to catch up.
time.sleep(0.1)

start_time = time.time()
while len(inbound_messages) != 1:
if time.time() - start_time >= 30:
break
time.sleep(0.1)

self.assertEqual(len(inbound_messages), 1)

@setup(queue=True)
Expand All @@ -308,7 +324,8 @@ def test_functional_redelivered(self):

self.channel.confirm_deliveries()
self.channel.basic.publish(body=self.message,
routing_key=self.queue_name)
routing_key=self.queue_name,
mandatory=True)

# Sleep for 0.1s to make sure RabbitMQ has time to catch up.
time.sleep(0.1)
Expand All @@ -323,14 +340,17 @@ def on_message(message):
inbound_messages.append(message)
self.assertTrue(message.redelivered)

# Sleep for 0.1s to make sure RabbitMQ has time to catch up.
time.sleep(0.1)

self.channel.basic.consume(callback=on_message,
queue=self.queue_name,
no_ack=True)

self.channel.process_data_events()
start_time = time.time()
while len(inbound_messages) == 0:
self.channel.process_data_events()
if time.time() - start_time >= 30:
break
time.sleep(0.1)

self.assertEqual(len(inbound_messages), 1)

@setup(queue=True)
Expand Down
5 changes: 4 additions & 1 deletion amqpstorm/tests/functional/management/connection_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ class ApiConnectionFunctionalTests(TestFunctionalFramework):
def test_api_connection_get(self):
api = ManagementApi(HTTP_URL, USERNAME, PASSWORD)

for conn in api.connection.list():
connections = api.connection.list()
self.assertIsNotNone(connections)

for conn in connections:
self.assertIsInstance(api.connection.get(conn['name']), dict)

@setup()
Expand Down
5 changes: 3 additions & 2 deletions amqpstorm/tests/functional/reliability_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,9 @@ def test_functional_publish_and_consume_5k_messages(self):
consumer_thread.start()

start_time = time.time()
while (self.messages_consumed != self.messages_to_send and
time.time() - start_time < 60):
while self.messages_consumed != self.messages_to_send:
if time.time() - start_time >= 60:
break
time.sleep(0.1)

self.assertEqual(self.messages_consumed, self.messages_to_send,
Expand Down
15 changes: 11 additions & 4 deletions amqpstorm/tests/functional/web_based_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,22 +61,29 @@ def test_functional_remove_queue_while_consuming(self):

self.assertFalse(self.channel._inbound)

@setup()
@setup(queue=True)
def test_functional_connection_forcefully_closed(self):
self.channel.confirm_deliveries()
self.channel.queue.declare(self.queue_name)

connection_list = retry_function_wrapper(self.api.connection.list)
self.assertIsNotNone(connection_list)

for connection in connection_list:
self.api.connection.close(connection['name'])

# Sleep for 1s to make sure RabbitMQ has time to catch up.
time.sleep(1)
start_time = time.time()
while len(self.api.connection.list()) > 0:
if time.time() - start_time >= 60:
break
time.sleep(1)

self.assertRaisesRegexp(
AMQPConnectionError,
'Connection was closed by remote server: '
'CONNECTION_FORCED - Closed via management api',
self.channel.basic.publish, 'body', 'routing_key'
self.channel.basic.publish, 'body', self.queue_name, '',
None, True, False
)

self.assertRaisesRegexp(
Expand Down
2 changes: 1 addition & 1 deletion amqpstorm/tests/unit/channel/channel_exception_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_chanel_callback_not_set(self):

self.assertRaisesRegexp(
AMQPChannelError,
'no consumer_callback defined',
'no consumer callback defined',
channel.process_data_events
)

Expand Down
36 changes: 24 additions & 12 deletions amqpstorm/tests/unit/channel/channel_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def test_channel_process_data_events(self):
message = self.message.encode('utf-8')
message_len = len(message)

deliver = specification.Basic.Deliver()
deliver = specification.Basic.Deliver(consumer_tag='travis-ci')
header = ContentHeader(body_size=message_len)
body = ContentBody(value=message)

Expand All @@ -122,7 +122,7 @@ def callback(msg):
self.assertIsInstance(msg.body, str)
self.assertEqual(msg.body.encode('utf-8'), message)

channel.consumer_callback = callback
channel._consumer_callbacks['travis-ci'] = callback
channel.process_data_events()

def test_channel_process_data_events_as_tuple(self):
Expand All @@ -132,7 +132,7 @@ def test_channel_process_data_events_as_tuple(self):
message = self.message.encode('utf-8')
message_len = len(message)

deliver = specification.Basic.Deliver()
deliver = specification.Basic.Deliver(consumer_tag='travis-ci')
header = ContentHeader(body_size=message_len)
body = ContentBody(value=message)

Expand All @@ -145,7 +145,7 @@ def callback(body, channel, method, properties):
self.assertIsInstance(properties, dict)
self.assertEqual(body, message)

channel.consumer_callback = callback
channel._consumer_callbacks['travis-ci'] = callback
channel.process_data_events(to_tuple=True)

def test_channel_start_consuming(self):
Expand All @@ -155,7 +155,7 @@ def test_channel_start_consuming(self):
message = self.message.encode('utf-8')
message_len = len(message)

deliver = specification.Basic.Deliver()
deliver = specification.Basic.Deliver(consumer_tag='travis-ci')
header = ContentHeader(body_size=message_len)
body = ContentBody(value=message)

Expand All @@ -166,28 +166,40 @@ def callback(msg):
self.assertEqual(msg.body.encode('utf-8'), message)
channel.set_state(channel.CLOSED)

channel.consumer_callback = callback
channel.add_consumer_tag('travis-ci')
channel._consumer_callbacks['travis-ci'] = callback
channel.start_consuming()

def test_channel_start_consuming_no_consumer_tag(self):
def test_channel_start_consuming_multiple_callbacks(self):
channel = Channel(0, FakeConnection(), 360)
channel.set_state(channel.OPEN)

message = self.message.encode('utf-8')
message_len = len(message)

deliver = specification.Basic.Deliver()
deliver_one = specification.Basic.Deliver(consumer_tag='travis-ci-1')
deliver_two = specification.Basic.Deliver(consumer_tag='travis-ci-2')
header = ContentHeader(body_size=message_len)
body = ContentBody(value=message)

channel._inbound = [deliver, header, body]
channel._inbound = [
deliver_one, header, body,
deliver_two, header, body
]

def callback(msg):
def callback_one(msg):
self.assertEqual(msg.method.get('consumer_tag'), 'travis-ci-1')
self.assertIsInstance(msg.body, str)
self.assertEqual(msg.body.encode('utf-8'), message)

channel.consumer_callback = callback
def callback_two(msg):
self.assertEqual(msg.method.get('consumer_tag'), 'travis-ci-2')
self.assertIsInstance(msg.body, str)
self.assertEqual(msg.body.encode('utf-8'), message)
channel.set_state(channel.CLOSED)

channel._consumer_callbacks['travis-ci-1'] = callback_one
channel._consumer_callbacks['travis-ci-2'] = callback_two

channel.start_consuming()

def test_channel_open(self):
Expand Down
Loading

0 comments on commit 9efeaec

Please sign in to comment.