Rework connection handling for better resilience

* Include a suite of network failure tests using a TCP proxy
* Run all operations in a retry loop
* Separate consumer and pool connection handling
* Manage producers ourselves
commit 46ad0e8bd82e1b43286240cde25a47b3de500fd1 (parent 3b5669e)
authored by @labisso
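
For context, a minimal sketch of how a caller might use the reworked API (the broker URI, service names, and operation names here are illustrative; retry and errback are the constructor arguments this commit introduces):

import logging
import sys

import dashi
from dashi.util import RetryBackoff

logging.basicConfig()
log = logging.getLogger(__name__)

def errback():
    # dashi invokes this from within its except blocks, so sys.exc_info()
    # still holds the active connection error
    log.warning("connection trouble: %s", sys.exc_info()[1])

conn = dashi.DashiConnection(
    "myservice", "amqp://guest:guest@localhost:5672", "myexchange",
    retry=RetryBackoff(max_attempts=10, backoff_max=3.0),
    errback=errback)

conn.fire("otherservice", "an_operation", some_arg=1)   # no reply expected
result = conn.call("otherservice", "another_op")        # blocks, default timeout=10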
dashi/__init__.py (406 lines changed)
@@ -4,32 +4,38 @@
import traceback
import uuid
import sys
+import time
import logging
-
+import itertools
from datetime import datetime, timedelta
+
from kombu.connection import Connection
-from kombu.messaging import Consumer
-from kombu.pools import connections, producers
+from kombu.messaging import Consumer, Producer
+from kombu.pools import connections
from kombu.entity import Queue, Exchange
-from kombu.common import maybe_declare
-from exceptions import DashiError, BadRequestError, NotFoundError, UnknownOperationError, WriteConflictError
+from .exceptions import DashiError, BadRequestError, NotFoundError, \
+ UnknownOperationError, WriteConflictError
+from .util import Countdown, RetryBackoff
__version__ = '0.2.7'
log = logging.getLogger(__name__)
-DEFAULT_HEARTBEAT = None # Disabled for now
+DEFAULT_HEARTBEAT = 30
+DEFAULT_QUEUE_EXPIRATION = 60.0
+
-class DashiConnection(object):
+class Dashi(object):
consumer_timeout = 1.0
- #TODO support connection info instead of uri
+ timeout_error = socket.timeout
- def __init__(self, name, uri, exchange, durable=False, auto_delete=True,
- serializer=None, transport_options=None, ssl=False,
- heartbeat=DEFAULT_HEARTBEAT, sysname=None):
+ def __init__(self, name, uri, exchange, durable=False, auto_delete=False,
+ serializer=None, transport_options=None, ssl=False,
+ heartbeat=DEFAULT_HEARTBEAT, sysname=None, retry=None,
+ errback=None):
"""Set up a Dashi connection
@param name: name of destination service queue used by consumers
@@ -43,11 +49,21 @@ def __init__(self, name, uri, exchange, durable=False, auto_delete=True,
@param transport_options: custom parameter dict for the transport backend
@param heartbeat: amqp heartbeat interval
@param sysname: a prefix for exchanges and queues for namespacing
+ @param retry: a RetryBackoff object, or None to use defaults
+ @param errback: callback invoked from the except block of connection failures
"""
self._heartbeat_interval = heartbeat
self._conn = Connection(uri, transport_options=transport_options,
ssl=ssl, heartbeat=self._heartbeat_interval)
+ if heartbeat:
+ # create a connection template for pooled connections. These cannot
+ # have heartbeat enabled.
+ self._pool_conn = Connection(uri, transport_options=transport_options,
+ ssl=ssl)
+ else:
+ self._pool_conn = self._conn
+
self._name = name
self._sysname = sysname
if self._sysname is not None:
@@ -61,13 +77,19 @@ def __init__(self, name, uri, exchange, durable=False, auto_delete=True,
self.durable = durable
self.auto_delete = auto_delete
- self._consumer_conn = None
self._consumer = None
self._linked_exceptions = {}
self._serializer = serializer
+ if retry is None:
+ self.retry = RetryBackoff()
+ else:
+ self.retry = retry
+
+ self._errback = errback
+
@property
def sysname(self):
return self._sysname
@@ -76,12 +98,6 @@ def sysname(self):
def name(self):
return self._name
- def add_sysname(self, name):
- if self.sysname is not None:
- return "%s.%s" % (self.sysname, name)
- else:
- return name
-
def fire(self, name, operation, args=None, **kwargs):
"""Send a message without waiting for a reply
@@ -101,12 +117,20 @@ def fire(self, name, operation, args=None, **kwargs):
d = dict(op=operation, args=args)
headers = {'sender': self.add_sysname(self.name)}
- with producers[self._conn].acquire(block=True) as producer:
- maybe_declare(self._exchange, producer.channel)
- producer.publish(d, routing_key=self.add_sysname(name), exchange=self._exchange_name,
- headers=headers, serializer=self._serializer)
+ dest = self.add_sysname(name)
+
+ def _fire(channel):
+ with Producer(channel) as producer:
+ producer.publish(d, routing_key=dest,
+ headers=headers, serializer=self._serializer,
+ exchange=self._exchange, declare=[self._exchange])
+
+ log.debug("sending message to %s", dest)
+ with connections[self._pool_conn].acquire(block=True) as conn:
+ _, channel = self.ensure(conn, _fire)
+ conn.maybe_close_channel(channel)
- def call(self, name, operation, timeout=5, args=None, **kwargs):
+ def call(self, name, operation, timeout=10, args=None, **kwargs):
"""Send a message and wait for reply
@param name: name of destination service queue
@@ -123,58 +147,125 @@ def call(self, name, operation, timeout=5, args=None, **kwargs):
else:
args = kwargs
- # create a direct exchange and queue for the reply. This may end up
- # being a bottleneck for performance: each rpc call gets a brand new
- # direct exchange and exclusive queue. However this approach is used
- # in nova.rpc and seems to have carried them pretty far. If/when this
+ # create a direct queue for the reply. This may end up being a
+ # bottleneck for performance: each rpc call gets a brand new
+ # exclusive queue. However this approach is used in nova.rpc and
+ # seems to have carried them pretty far. If/when this
# becomes a bottleneck we can set up a long-lived backend queue and
# use correlation_id to deal with concurrent RPC calls. See:
# http://www.rabbitmq.com/tutorials/tutorial-six-python.html
msg_id = uuid.uuid4().hex
- exchange = Exchange(name=msg_id, type='direct',
- durable=False, auto_delete=True)
-
- # check out a connection from the pool
- with connections[self._conn].acquire(block=True) as conn:
- queue = Queue(name=msg_id, exchange=exchange, routing_key=msg_id,
- exclusive=True, durable=False, auto_delete=True)
- log.debug("declared call() reply queue %s", msg_id)
-
- messages = []
-
- def _callback(body, message):
- messages.append(body)
- message.ack()
-
- consumer = Consumer(conn, queues=(queue,), callbacks=(_callback,))
- consumer.declare()
-
- d = dict(op=operation, args=args)
- headers = {'reply-to': msg_id, 'sender': self.add_sysname(self.name)}
-
- with producers[self._conn].acquire(block=True) as producer:
- maybe_declare(self._exchange, producer.channel)
- log.debug("sending call to %s:%s", self.add_sysname(name), operation)
- producer.publish(d, routing_key=self.add_sysname(name), headers=headers,
- exchange=self._exchange, serializer=self._serializer)
-
- with consumer:
- log.debug("awaiting call reply on %s", msg_id)
- # only expecting one event
- conn.drain_events(timeout=timeout)
-
- msg_body = messages[0]
- if msg_body.get('error'):
- raise_error(msg_body['error'])
- else:
- return msg_body.get('result')
-
- def reply(self, msg_id, body):
- with producers[self._conn].acquire(block=True) as producer:
+
+ # expire the reply queue shortly after the timeout. it will be
+ # (lazily) deleted by the broker if we don't clean it up first
+ queue_arguments = {'x-expires': int((timeout + 1) * 1000)}
+ queue = Queue(name=msg_id, exchange=self._exchange, routing_key=msg_id,
+ durable=False, queue_arguments=queue_arguments)
+
+ messages = []
+ event = threading.Event()
+
+ def _callback(body, message):
+ messages.append(body)
+ message.ack()
+ log.debug("setting event")
+ event.set()
+
+ d = dict(op=operation, args=args)
+ headers = {'reply-to': msg_id, 'sender': self.add_sysname(self.name)}
+ dest = self.add_sysname(name)
+
+ def _declare_and_send(channel):
+ consumer = Consumer(channel, (queue,), callbacks=(_callback,))
+ with Producer(channel) as producer:
+ producer.publish(d, routing_key=dest, headers=headers,
+ exchange=self._exchange, serializer=self._serializer)
+ return consumer
+
+ log.debug("sending call to %s:%s", dest, operation)
+ with connections[self._pool_conn].acquire(block=True) as conn:
+ consumer, channel = self.ensure(conn, _declare_and_send)
+ try:
+ self._consume(conn, consumer, timeout=timeout, until_event=event)
+ finally:
+ conn.maybe_close_channel(channel)
+
+ msg_body = messages[0]
+ if msg_body.get('error'):
+ raise_error(msg_body['error'])
+ else:
+ return msg_body.get('result')
+
+ def _consume(self, connection, consumer, count=None, timeout=None, until_event=None):
+ if count is not None:
+ if count <= 0:
+ raise ValueError("count must be >= 1")
+ consumed = itertools.count(1)
+
+ inner_timeout = self.consumer_timeout
+ if timeout is not None:
+ timeout = Countdown.from_value(timeout)
+ inner_timeout = min(timeout.timeleft, inner_timeout)
+
+ if until_event and until_event.is_set():
+ return
+
+ needs_heartbeat = connection.heartbeat and connection.supports_heartbeats
+ if needs_heartbeat:
+ time_between_tics = timedelta(seconds=connection.heartbeat / 2.0)
+ if self.consumer_timeout > time_between_tics.seconds:
+ msg = "dashi consumer timeout (%s) must be half or smaller than the heartbeat interval %s" % (
+ self.consumer_timeout, connection.heartbeat)
+ raise DashiError(msg)
+ last_heartbeat_check = datetime.min
+
+ reconnect = False
+ declare = True
+ while 1:
try:
- producer.publish(body, routing_key=msg_id, exchange=msg_id, serializer=self._serializer)
- except self._conn.channel_errors:
- log.exception("Failed to reply to msg %s", msg_id)
+ if declare:
+ consumer.consume()
+ declare = False
+
+ if needs_heartbeat:
+ if datetime.now() - last_heartbeat_check > time_between_tics:
+ last_heartbeat_check = datetime.now()
+ connection.heartbeat_check()
+
+ connection.drain_events(timeout=inner_timeout)
+ if count and next(consumed) == count:
+ return
+
+ except socket.timeout:
+ pass
+ except (connection.connection_errors, IOError):
+ log.debug("Received error consuming", exc_info=True)
+ self._call_errback()
+ reconnect = True
+
+ if until_event is not None and until_event.is_set():
+ return
+
+ if timeout:
+ inner_timeout = min(inner_timeout, timeout.timeleft)
+ if not inner_timeout:
+ raise self.timeout_error()
+
+ if reconnect:
+ self.connect(connection, (consumer,), timeout=timeout)
+ reconnect = False
+ declare = True
+
+ def reply(self, connection, msg_id, body):
+ def _reply(channel):
+ with Producer(channel) as producer:
+ producer.publish(body, routing_key=msg_id, exchange=self._exchange,
+ serializer=self._serializer)
+
+ log.debug("replying to %s", msg_id)
+ _, channel = self.ensure(connection, _reply)
+ connection.maybe_close_channel(channel)
def handle(self, operation, operation_name=None, sender_kwarg=None):
"""Handle an operation using the specified function
@@ -184,8 +275,7 @@ def handle(self, operation, operation_name=None, sender_kwarg=None):
@param sender_kwarg: optional keyword arg on operation to feed in sender name
"""
if not self._consumer:
- self._consumer_conn = connections[self._conn].acquire()
- self._consumer = DashiConsumer(self, self._consumer_conn,
+ self._consumer = DashiConsumer(self, self._conn,
self._name, self._exchange, sysname=self._sysname)
self._consumer.add_op(operation_name or operation.__name__, operation,
sender_kwarg=sender_kwarg)
@@ -224,6 +314,80 @@ def link_exceptions(self, custom_exception=None, dashi_exception=None):
self._linked_exceptions[custom_exception] = dashi_exception
+ def _call_errback(self):
+ if not self._errback:
+ return
+ try:
+ self._errback()
+ except Exception:
+ log.exception("error calling errback..")
+
+ def add_sysname(self, name):
+ if self.sysname is not None:
+ return "%s.%s" % (self.sysname, name)
+ else:
+ return name
+
+ def connect(self, connection, entities=None, timeout=None):
+ if timeout is not None:
+ timeout = Countdown.from_value(timeout)
+ backoff = iter(self.retry)
+ while 1:
+
+ this_backoff = next(backoff, False)
+
+ try:
+ channel = self._connect(connection)
+ if entities:
+ for entity in entities:
+ entity.revive(channel)
+ return channel
+
+ except (connection.connection_errors, IOError):
+ if this_backoff is False:
+ log.exception("Error connecting to broker. Giving up.")
+ raise
+ self._call_errback()
+
+ if timeout:
+ timeleft = timeout.timeleft
+ if not timeleft:
+ raise self.timeout_error()
+ elif timeleft < this_backoff:
+ this_backoff = timeleft
+
+ log.exception("Error connecting to broker. Retrying in %ss", this_backoff)
+ time.sleep(this_backoff)
+
+ def _connect(self, connection):
+ # close out previous connection first
+ try:
+ # dirty: breaking into kombu to force close the connection
+ connection._close()
+ except connection.connection_errors:
+ pass
+
+ connection.connect()
+ return connection.channel()
+
+ def ensure(self, connection, func, *args, **kwargs):
+ """Perform an operation until success
+
+ Repeats in the face of connection errors, pursuant to retry policy
+ """
+ channel = None
+ while 1:
+ try:
+ if channel is None:
+ channel = connection.channel()
+ return func(channel, *args, **kwargs), channel
+ except (connection.connection_errors, IOError):
+ self._call_errback()
+
+ channel = self.connect(connection)
+
+# alias for compatibility
+DashiConnection = Dashi
_OpSpec = namedtuple('_OpSpec', ['function', 'sender_kwarg'])
@@ -236,35 +400,38 @@ def __init__(self, dashi, connection, name, exchange, sysname=None):
self._exchange = exchange
self._sysname = sysname
- self._channel = None
self._ops = {}
- self._cancelled = False
+ self._cancelled = threading.Event()
self._consumer_lock = threading.Lock()
- self._last_heartbeat_check = datetime.min
+
+ if self._sysname is not None:
+ self._queue_name = "%s.%s" % (self._sysname, self._name)
+ else:
+ self._queue_name = self._name
+
+ self._queue_kwargs = dict(
+ name=self._queue_name,
+ exchange=self._exchange,
+ routing_key=self._queue_name,
+ durable=self._dashi.durable,
+ auto_delete=self._dashi.auto_delete,
+ queue_arguments={'x-expires': int(DEFAULT_QUEUE_EXPIRATION * 1000)})
self.connect()
def connect(self):
+ self._dashi.ensure(self._conn, self._connect)
- self._channel = self._conn.channel()
-
- if self._sysname is not None:
- name = "%s.%s" % (self._sysname, self._name)
- else:
- name = self._name
- self._queue = Queue(channel=self._channel, name=name,
- exchange=self._exchange, routing_key=name,
- durable=self._dashi.durable,
- auto_delete=self._dashi.auto_delete)
+ def _connect(self, channel):
+ self._queue = Queue(channel=channel, **self._queue_kwargs)
self._queue.declare()
- self._consumer = Consumer(self._channel, [self._queue],
- callbacks=[self._callback])
+ self._consumer = Consumer(channel, [self._queue],
+ callbacks=[self._callback])
self._consumer.consume()
def disconnect(self):
self._consumer.cancel()
- self._channel.close()
self._conn.release()
def consume(self, count=None, timeout=None):
@@ -276,67 +443,14 @@ def consume(self, count=None, timeout=None):
raise Exception("only one consumer thread may run concurrently")
try:
- if count:
- i = 0
- while i < count and not self._cancelled:
- self._consume_one(timeout)
- i += 1
- else:
- while not self._cancelled:
- self._consume_one(timeout)
+ self._dashi._consume(self._conn, self._consumer, count=count,
+ timeout=timeout, until_event=self._cancelled)
finally:
self._consumer_lock.release()
- self._cancelled = False
-
- def _consume_one(self, timeout=None):
-
- # do consuming in a busy-ish loop, checking for cancel. There doesn't
- # seem to be an easy way to interrupt drain_events other than the
- # timeout. This could probably be added to kombu if needed. In
- # practice cancellation is likely infrequent (except in tests) so this
- # should hold for now. Can use a long timeout for production and a
- # short one for tests.
-
- inner_timeout = self._dashi.consumer_timeout
- elapsed = 0
-
- # keep trying until a single event is drained or timeout hit
- while not self._cancelled:
-
- self.heartbeat()
-
- try:
- self._conn.drain_events(timeout=inner_timeout)
- break
-
- except socket.timeout:
- if timeout:
- elapsed += inner_timeout
- if elapsed >= timeout:
- raise
-
- if elapsed + inner_timeout > timeout:
- inner_timeout = timeout - elapsed
-
- def heartbeat(self):
- if self._dashi._heartbeat_interval is None:
- return
-
- time_between_tics = timedelta(seconds=self._dashi._heartbeat_interval / 2)
-
- if self._dashi.consumer_timeout > time_between_tics.seconds:
- msg = "dashi consumer timeout (%s) must be half or smaller than the heartbeat interval %s" % (
- self._dashi.consumer_timeout, self._dashi._heartbeat_interval)
-
- raise DashiError(msg)
-
- if datetime.now() - self._last_heartbeat_check > time_between_tics:
- self._last_heartbeat_check = datetime.now()
- self._conn.heartbeat_check()
-
+ self._cancelled.clear()
def cancel(self, block=True):
- self._cancelled = True
+ self._cancelled.set()
if block:
# acquire the lock and release it immediately
with self._consumer_lock:
@@ -405,7 +519,7 @@ def _callback(self, body, message):
traceback=tb)
reply = dict(result=ret, error=err)
- self._dashi.reply(reply_to, reply)
+ self._dashi.reply(self._conn, reply_to, reply)
message.ack()
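
The comment in call() above points at a possible future optimization: a single long-lived reply queue plus correlation_id, per the linked RabbitMQ tutorial, instead of one exclusive queue per call. A rough sketch of that alternative with kombu (not part of this commit; queue and routing names are illustrative):

import uuid

from kombu.connection import Connection
from kombu.entity import Queue
from kombu.messaging import Consumer, Producer

conn = Connection("amqp://guest:guest@localhost:5672")
channel = conn.channel()

# one exclusive reply queue for this client, reused across all calls
reply_queue = Queue("myclient.replies", exclusive=True, auto_delete=True,
                    channel=channel)
reply_queue.declare()

replies = {}

def on_reply(body, message):
    # match each reply to the call waiting on it
    replies[message.properties.get("correlation_id")] = body
    message.ack()

corr_id = uuid.uuid4().hex
with Producer(channel) as producer:
    producer.publish(dict(op="test", args={}), routing_key="otherservice",
                     reply_to=reply_queue.name, correlation_id=corr_id)

with Consumer(channel, [reply_queue], callbacks=[on_reply]):
    while corr_id not in replies:
        conn.drain_events(timeout=5)
result = replies.pop(corr_id)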
dashi/tests/test_dashi.py (240 lines changed)
@@ -6,8 +6,8 @@
import uuid
import logging
import time
+import sys
-from mock import Mock
from nose.plugins.skip import SkipTest
from kombu.pools import connections
import kombu.pools
@@ -15,13 +15,15 @@
import dashi
import dashi.util
from dashi.exceptions import DashiError
-from dashi.tests.util import who_is_calling
+from dashi.tests.util import who_is_calling, SocatProxy
log = logging.getLogger(__name__)
_NO_EXCEPTION = object()
_NO_REPLY = object()
+retry = dashi.util.RetryBackoff(max_attempts=10, backoff_max=3.0)
+
def assert_kombu_pools_empty():
@@ -52,7 +54,7 @@ def assert_kombu_pools_empty():
class TestReceiver(object):
- consume_timeout = 5
+ consume_timeout = 30
def __init__(self, **kwargs):
@@ -80,6 +82,7 @@ def handle(self, opname, reply_with=_NO_REPLY, raise_exception=_NO_EXCEPTION, **
self.conn.handle(partial(self._handler, opname), opname, **kwargs)
def _handler(self, opname, **kwargs):
+ log.debug("TestReceiver(%s) got op=%s: %s", self.name, opname, kwargs)
with self.condition:
self.received.append((opname, kwargs))
self.condition.notifyAll()
@@ -106,7 +109,7 @@ def wait(self, timeout=5, pred=None):
self.condition.wait(remaining)
now = time.time()
if now - start >= timeout and not pred(self.received):
- raise Exception("timed out waiting for messages")
+ raise Exception("timed out waiting for messages. had: %s", self.received)
remaining -= now - start
def consume(self, count):
@@ -229,8 +232,6 @@ def test_sysname(self):
ret = conn.call(receiver.name, "test", **args1)
self.assertEqual(ret, i)
- time.sleep(10)
-
receiver.join_consumer_thread()
receiver.disconnect()
@@ -400,37 +401,6 @@ def test_handle_sender_kwarg(self):
receiver.disconnect()
assert_kombu_pools_empty()
- def test_heartbeats(self):
-
- receiver = TestReceiver(uri=self.uri, exchange="x1",
- transport_options=self.transport_options, heartbeat=30)
- receiver.conn.consumer_timeout = 100
-
- receiver.handle("test1", "hello", sender_kwarg="sender")
-
- caught_exp = None
- try:
- receiver.consume(1)
- except DashiError, e:
- caught_exp = e
- assert caught_exp
-
- receiver.conn.consumer_timeout = 0.1
- caught_timeout = None
- try:
- receiver.consume(1)
- except socket.timeout, e:
- caught_timeout = e
- assert caught_timeout
-
- receiver.clear()
-
- receiver.cancel()
-
- receiver.disconnect()
- assert_kombu_pools_empty()
-
-
def test_exceptions(self):
class CustomNotFoundError(Exception):
pass
@@ -446,7 +416,7 @@ class CustomNotFoundError(Exception):
args1 = dict(a=1, b="sandwich")
try:
- ret = conn.call(receiver.name, "test_exception", **args1)
+ conn.call(receiver.name, "test_exception", **args1)
except dashi.exceptions.NotFoundError:
pass
else:
@@ -504,7 +474,7 @@ def _thread_erroneous_replies(self, dashiconn, count):
log.exception("Got expected exception replying to a nonexistent exchange")
def test_pool_problems(self):
- raise SkipTest("failing test that exposes problem in dashi RPC strategy")
+ # raise SkipTest("failing test that exposes problem in dashi RPC strategy")
# this test fails (I think) because replies are sent to a nonexistent
# exchange. Rabbit freaks out about this and poisons the channel.
@@ -534,3 +504,195 @@ def test_pool_problems(self):
receiver.wait(pred=pred)
self.assertEqual(len(receiver.received), 100)
+
+
+class RabbitProxyDashiConnectionTests(RabbitDashiConnectionTests):
+ """Test rabbitmq dashi through a TCP proxy that we can kill to simulate failures
+
+ Run all the above rabbit tests too, to make sure proxy behaves ok
+ """
+
+ @classmethod
+ def setUpClass(cls):
+ cls.proxy = SocatProxy("localhost:5672")
+ cls.proxy.start()
+ cls.uri = "amqp://guest:guest@localhost:%s" % cls.proxy.port
+ cls.real_uri = "amqp://guest:guest@localhost:5672"
+
+ @classmethod
+ def tearDownClass(cls):
+ if cls.proxy:
+ cls.proxy.stop()
+
+ def setUp(self):
+ if not self.proxy.running:
+ self.proxy.start()
+
+ def _make_chained_proxy(self, proxy):
+ second_proxy = SocatProxy(proxy.address, destination_options="ignoreeof")
+ self.addCleanup(second_proxy.stop)
+ second_proxy.start()
+
+ uri = "amqp://guest:guest@localhost:%s" % second_proxy.port
+ return second_proxy, uri
+
+ def test_call_kill_pool_connection(self):
+ # use a pool connection, kill the connection, and then try to reuse it
+
+ # put receiver directly on rabbit. not via proxy
+ receiver = TestReceiver(uri=self.real_uri, exchange="x1",
+ transport_options=self.transport_options, retry=retry)
+ replies = [5, 4, 3, 2, 1]
+ receiver.handle("test", replies.pop)
+ receiver.consume_in_thread()
+
+ conn = dashi.DashiConnection("s1", self.uri, "x1",
+ transport_options=self.transport_options, retry=retry)
+
+ ret = conn.call(receiver.name, "test")
+ self.assertEqual(ret, 1)
+
+ for i in list(reversed(replies)):
+ self.proxy.restart()
+ ret = conn.call(receiver.name, "test")
+ self.assertEqual(ret, i)
+
+ receiver.cancel()
+ receiver.join_consumer_thread()
+ receiver.disconnect()
+
+ assert_kombu_pools_empty()
+
+ def test_call_kill_before_reply(self):
+
+ # have the receiver handler restart the sender's connection
+ # while it is waiting for a reply
+
+ def killit():
+ self.proxy.restart()
+ return True
+
+ # put receiver directly on rabbit. not via proxy
+ receiver = TestReceiver(uri=self.real_uri, exchange="x1",
+ transport_options=self.transport_options, retry=retry)
+ receiver.handle("killme", killit)
+ receiver.consume_in_thread()
+
+ for _ in range(5):
+ conn = dashi.DashiConnection("s1", self.uri, "x1",
+ transport_options=self.transport_options, retry=retry)
+ ret = conn.call(receiver.name, "killme")
+ self.assertEqual(ret, True)
+
+ receiver.cancel()
+ receiver.join_consumer_thread()
+ receiver.disconnect()
+
+ assert_kombu_pools_empty()
+
+ def test_fire_kill_pool_connection(self):
+ # use a pool connection, kill the connection, and then try to reuse it
+
+ # put receiver directly on rabbit. not via proxy
+ receiver = TestReceiver(uri=self.real_uri, exchange="x1",
+ transport_options=self.transport_options, retry=retry)
+ receiver.handle("test")
+ receiver.consume_in_thread()
+
+ conn = dashi.DashiConnection("s1", self.uri, "x1",
+ transport_options=self.transport_options, retry=retry)
+
+ conn.fire(receiver.name, "test", hats=0)
+ receiver.wait(pred=lambda r: len(r) == 1)
+ self.assertEqual(receiver.received[0], ("test", {"hats": 0}))
+
+ for i in range(1, 4):
+ self.proxy.restart()
+ conn.fire(receiver.name, "test", hats=i)
+
+ receiver.wait(pred=lambda r: len(r) == 4)
+ self.assertEqual(receiver.received[1], ("test", {"hats": 1}))
+ self.assertEqual(receiver.received[2], ("test", {"hats": 2}))
+ self.assertEqual(receiver.received[3], ("test", {"hats": 3}))
+
+ receiver.cancel()
+ receiver.join_consumer_thread()
+ receiver.disconnect()
+
+ assert_kombu_pools_empty()
+
+ def test_receiver_kill_connection(self):
+ # restart a consumer's connection. it should reconnect and keep consuming
+ receiver = TestReceiver(uri=self.uri, exchange="x1",
+ transport_options=self.transport_options, retry=retry)
+ receiver.handle("test", "hats")
+ receiver.consume_in_thread()
+
+ # put caller directly on rabbit, not proxy
+ conn = dashi.DashiConnection("s1", self.real_uri, "x1",
+ transport_options=self.transport_options, retry=retry)
+ self.assertEqual(conn.call(receiver.name, "test"), "hats")
+
+ self.proxy.restart()
+
+ self.assertEqual(conn.call(receiver.name, "test"), "hats")
+ self.assertEqual(conn.call(receiver.name, "test"), "hats")
+
+ receiver.cancel()
+ receiver.join_consumer_thread()
+ receiver.disconnect()
+
+ assert_kombu_pools_empty()
+
+ def test_heartbeat_kill(self):
+ # create a second tier proxy. then we can kill the backend proxy
+ # and our connection remains "open"
+ chained_proxy, chained_uri = self._make_chained_proxy(self.proxy)
+
+ event = threading.Event()
+
+ # attach an errback to the receiver that is called by dashi
+ # with any connection failures
+ def errback():
+ log.debug("Errback called", exc_info=True)
+ exc = sys.exc_info()[1]
+ if "Too many heartbeats missed" in str(exc):
+ log.debug("we got the beat!")
+ event.set()
+
+ receiver = TestReceiver(uri=chained_uri, exchange="x1",
+ transport_options=self.transport_options, heartbeat=2.0,
+ errback=errback, retry=retry)
+ receiver.handle("test", "hats")
+ receiver.consume_in_thread()
+
+ # put caller directly on rabbit, not proxy
+ conn = dashi.DashiConnection("s1", self.real_uri, "x1",
+ transport_options=self.transport_options, retry=retry)
+ self.assertEqual(conn.call(receiver.name, "test"), "hats")
+
+ # kill the proxy and wait for the errback from amqp client
+ self.proxy.stop()
+
+ # try a few times to get the heartbeat loss error. depending on
+ # timing, sometimes we just get a connection loss error
+ for _ in range(4):
+ event.wait(5)
+ if event.is_set():
+ break
+ else:
+ self.proxy.start()
+ time.sleep(3) # give it time to reconnect
+ self.proxy.stop()
+ assert event.is_set()
+
+ # restart and we should be back up and running
+ self.proxy.start()
+ self.assertEqual(conn.call(receiver.name, "test"), "hats")
+
+ receiver.cancel()
+ receiver.join_consumer_thread()
+ receiver.disconnect()
+
+ assert_kombu_pools_empty()
+
dashi/tests/test_util.py (5 lines changed)
@@ -60,7 +60,7 @@ def test_start_stop(self):
with self.condition:
if not self.calls:
self.condition.wait(5)
-
+
self.assertEqual(self.calls, 1)
self.assertLastPassed(1, hats=True)
@@ -99,14 +99,13 @@ def test_error_caught(self):
with self.condition:
while not self.calls >= 3:
-
+
self.condition.wait()
self.loop.stop()
self.assertGreaterEqual(self.calls, 3)
-
dashi/tests/util.py (64 lines changed)
@@ -1,4 +1,11 @@
import sys
+import subprocess
+import errno
+import os
+import unittest
+import signal
+import socket
+
def who_is_calling():
"""Returns the name of the caller's calling function.
@@ -7,3 +14,60 @@ def who_is_calling():
There must be a better way.
"""
return sys._getframe(2).f_code.co_name
+
+
+class SocatProxy(object):
+ """Manages a TCP forking proxy using socat
+ """
+
+ def __init__(self, destination, source_port=None, source_options=None, destination_options=None):
+ self.port = source_port or free_port()
+ self.address = "localhost:%d" % self.port
+ self.destination = destination
+ self.process = None
+ self.source_options = "," + str(source_options) if source_options else ""
+ self.destination_options = "," + str(destination_options) if destination_options else ""
+
+ def start(self):
+ assert not self.process
+ src_arg = "TCP4-LISTEN:%d,fork,reuseaddr%s" % (self.port, self.source_options)
+ dest_arg = "TCP4:%s%s" % (self.destination, self.destination_options)
+ try:
+ self.process = subprocess.Popen(args=["socat", src_arg, dest_arg],
+ preexec_fn=os.setpgrp)
+ except OSError, e:
+ if e.errno == errno.ENOENT:
+ raise unittest.SkipTest("socat executable not found")
+ raise
+
+ def stop(self):
+ if self.process and self.process.returncode is None:
+ try:
+ os.killpg(self.process.pid, signal.SIGKILL)
+ except OSError, e:
+ if e.errno != errno.ESRCH:
+ raise
+ self.process.wait()
+ self.process = None
+ return True
+ return False
+
+ def restart(self):
+ self.stop()
+ self.start()
+
+ @property
+ def running(self):
+ return self.process and self.process.returncode is None
+
+
+def free_port(host="localhost"):
+ """Pick a free port on a local interface and return it.
+
+ Races are possible but unlikely
+ """
+ sock = socket.socket()
+ try:
+ sock.bind((host, 0))
+ return sock.getsockname()[1]
+ finally:
+ sock.close()
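
Usage mirrors what the proxy tests above do; a short sketch (assumes a RabbitMQ broker on the local default port):

proxy = SocatProxy("localhost:5672")
proxy.start()                     # raises SkipTest if socat is not installed
uri = "amqp://guest:guest@localhost:%d" % proxy.port

# ... point a dashi connection at uri ...

proxy.restart()                   # drops every forwarded TCP connection
proxy.stop()                      # simulates the broker disappearing entirely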
dashi/util.py (56 lines changed)
@@ -1,5 +1,61 @@
import logging
import threading
+import time
+
+
+class Countdown(object):
+ _time_func = time.time
+
+ def __init__(self, timeout, time_func=None):
+ if time_func is not None:
+ self._time_func = time_func
+ self.timeout = timeout
+ self.expires = self._time_func() + timeout
+
+ @classmethod
+ def from_value(cls, timeout):
+ """Wraps a timeout value in a Countdown, unless it already is
+ """
+ if isinstance(timeout, cls):
+ return timeout
+ return cls(timeout)
+
+ @property
+ def expired(self):
+ return self._time_func() >= self.expires
+
+ @property
+ def timeleft(self):
+ """Number of seconds remaining before timeout
+ """
+ return max(0.0, self.expires - self._time_func())
+
+
+class RetryBackoff(object):
+ def __init__(self, max_attempts=0, backoff_start=0.5, backoff_step=0.5, backoff_max=30, timeout=None):
+ self.max_attempts = int(max_attempts)
+ self.backoff_start = float(backoff_start)
+ self.backoff_step = float(backoff_step)
+ self.backoff_max = float(backoff_max)
+
+ self.timeout = Countdown.from_value(timeout) if timeout else None
+
+ def __iter__(self):
+ retry = 1
+ backoff = self.backoff_start
+
+ while not self.max_attempts or retry <= self.max_attempts:
+
+ if self.timeout:
+ timeleft = self.timeout.timeleft
+ if not timeleft:
+ return
+ backoff = min(backoff, timeleft)  # never sleep past the deadline
+
+ yield backoff
+
+ backoff = min(backoff + self.backoff_step, self.backoff_max)
+ retry += 1
class LoopingCall(object):
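
For reference, a small sketch of how these two utilities compose, in the spirit of Dashi.connect() above (attempt_connect is a hypothetical operation that raises IOError on failure):

import time

from dashi.util import Countdown, RetryBackoff

deadline = Countdown.from_value(30)           # overall budget, in seconds
for delay in RetryBackoff(max_attempts=5):    # yields 0.5, 1.0, 1.5, 2.0, 2.5
    try:
        attempt_connect()                     # hypothetical; may raise IOError
        break
    except IOError:
        if not deadline.timeleft:
            raise
        time.sleep(min(delay, deadline.timeleft))
else:
    raise IOError("gave up after 5 attempts")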