Skip to content
Browse files

Cleanups, fix sporadic test failures

  • Loading branch information...
1 parent 99721a3 commit 1de418aff8b736bee70f34415801991751535461 @labisso labisso committed
Showing with 108 additions and 19 deletions.
  1. +7 −0 dashi/__init__.py
  2. +38 −2 dashi/tests/test_dashi.py
  3. +46 −4 dashi/tests/test_util.py
  4. +17 −13 dashi/util.py
View
7 dashi/__init__.py
@@ -140,6 +140,10 @@ def cancel(self, block=True):
if self._consumer:
self._consumer.cancel(block=block)
+ def disconnect(self):
+ if self._consumer:
+ self._consumer.disconnect()
+
class DashiConsumer(object):
def __init__(self, dashi, connection, name, exchange):
@@ -168,6 +172,9 @@ def connect(self):
callbacks=[self._callback])
self._consumer.consume()
+ def disconnect(self):
+ self._consumer.cancel()
+
def consume(self, count=None, timeout=None):
# hold a lock for the duration of the consuming. this prevents
View
40 dashi/tests/test_dashi.py
@@ -3,6 +3,7 @@
from functools import partial
import itertools
import uuid
+import logging
from kombu.pools import connections
@@ -10,7 +11,7 @@
import dashi.util
from dashi.tests.util import who_is_calling
-log = dashi.util.get_logger()
+log = logging.getLogger(__name__)
_NO_REPLY = object()
@@ -32,6 +33,7 @@ def __init__(self, **kwargs):
self.reply_with = {}
self.consumer_thread = None
+ self.condition = threading.Condition()
def handle(self, opname, reply_with=_NO_REPLY):
if reply_with is not _NO_REPLY:
@@ -39,13 +41,23 @@ def handle(self, opname, reply_with=_NO_REPLY):
self.conn.handle(partial(self._handler, opname), opname)
def _handler(self, opname, **kwargs):
- self.received.append((opname, kwargs))
+ with self.condition:
+ self.received.append((opname, kwargs))
+ self.condition.notifyAll()
+
if opname in self.reply_with:
reply_with = self.reply_with[opname]
if callable(reply_with):
return reply_with()
return reply_with
+ def wait(self, timeout=5):
+ with self.condition:
+ while not self.received:
+ self.condition.wait(timeout)
+ if not self.received:
+ raise Exception("timed out waiting for message")
+
def consume(self, count):
self.conn.consume(count=count, timeout=self.consume_timeout)
@@ -201,6 +213,30 @@ def test_cancel(self):
# this should hang forever if cancel doesn't work
receiver.join_consumer_thread()
+ def test_cancel_resume_cancel(self):
+ receiver = TestReceiver(uri=self.uri, exchange="x1")
+ receiver.handle("test", 1)
+ receiver.consume_in_thread()
+
+ conn = dashi.DashiConnection("s1", self.uri, "x1")
+ self.assertEqual(1, conn.call(receiver.name, "test"))
+
+ receiver.cancel()
+ receiver.join_consumer_thread()
+ receiver.clear()
+
+ # send message while receiver is cancelled
+ conn.fire(receiver.name, "test", hats=4)
+
+ # start up consumer again. message should arrive.
+ receiver.consume_in_thread()
+
+ receiver.wait()
+ self.assertEqual(receiver.received[-1], ("test", dict(hats=4)))
+
+ receiver.cancel()
+ receiver.join_consumer_thread()
+
class RabbitDashiConnectionTests(DashiConnectionTests):
"""The base dashi tests run on rabbit, plus some extras which are
View
50 dashi/tests/test_util.py
@@ -8,6 +8,7 @@ class LoopingCallTests(unittest.TestCase):
def setUp(self):
self.calls = 0
self.passed = []
+ self.condition = threading.Condition()
self.loop = None
@@ -15,9 +16,20 @@ def setUp(self):
# number of calls. self.loop must also be set.
self.max_calls = None
+ # tests can set this to make looper raise an exception
+ self.raise_this = None
+
# when looper kills itself, it will set this event
self.stopped = threading.Event()
+ def tearDown(self):
+ if self.loop:
+ # peek into loop and make sure thread is joined
+ self.loop.stop()
+ thread = self.loop.thread
+ if thread:
+ thread.join()
+
def assertPassed(self, index, *args, **kwargs):
passed_args, passed_kwargs = self.passed[index]
self.assertEqual(args, passed_args)
@@ -27,23 +39,33 @@ def assertLastPassed(self, *args, **kwargs):
self.assertPassed(-1, *args, **kwargs)
def looper(self, *args, **kwargs):
- self.calls += 1
- self.passed.append((args, kwargs))
+ with self.condition:
+ self.calls += 1
+ self.passed.append((args, kwargs))
+ self.condition.notifyAll()
if self.max_calls and self.calls >= self.max_calls:
self.loop.stop()
self.stopped.set()
+ if self.raise_this:
+ raise self.raise_this
+
def test_start_stop(self):
- loop = LoopingCall(self.looper, 1, hats=True)
+ self.loop = loop = LoopingCall(self.looper, 1, hats=True)
loop.start(1)
loop.stop()
+
+ with self.condition:
+ if not self.calls:
+ self.condition.wait(5)
+
self.assertEqual(self.calls, 1)
self.assertLastPassed(1, hats=True)
def test_start_stop_2(self):
- loop = LoopingCall(self.looper, 1, hats=True)
+ self.loop = loop = LoopingCall(self.looper, 1, hats=True)
loop.start(1, now=False)
loop.stop()
@@ -58,12 +80,32 @@ def test_called(self):
loop.start(0)
self.assertTrue(self.stopped.wait(5))
+ #peek into looping call and join on thread
+ thread = loop.thread
+ if thread:
+ thread.join()
+
self.assertFalse(loop.running)
self.assertEqual(self.calls, 3)
self.assertPassed(0, 1, 2, anarg=5)
self.assertPassed(1, 1, 2, anarg=5)
self.assertPassed(2, 1, 2, anarg=5)
+ def test_error_caught(self):
+ self.loop = LoopingCall(self.looper)
+ self.raise_this = Exception("too many sandwiches")
+
+ self.loop.start(0)
+
+ with self.condition:
+ while not self.calls >= 3:
+
+ self.condition.wait()
+
+ self.loop.stop()
+ self.assertGreaterEqual(self.calls, 3)
+
+
View
30 dashi/util.py
@@ -1,9 +1,6 @@
import logging
import threading
-def get_logger():
- return logging.getLogger('dashi')
-
class LoopingCall(object):
def __init__(self, fun, *args, **kwargs):
@@ -19,8 +16,12 @@ def __init__(self, fun, *args, **kwargs):
self.running = False
+ def __del__(self):
+ self.stop()
+
def start(self, interval, now=True):
assert self.thread is None
+ self.cancelled.clear()
self.running = True
self.thread = threading.Thread(target=self._looper,
@@ -29,22 +30,25 @@ def start(self, interval, now=True):
self.thread.start()
def stop(self):
- self.cancelled.set()
- self.thread = None
+ if self.thread:
+ self.cancelled.set()
def __call__(self):
try:
self.fun(*self.args, **self.kwargs)
except Exception:
- log = get_logger()
+ log = logging.getLogger(__name__)
log.exception("Error in looping call")
def _looper(self, interval, now):
- if now:
- self()
- while not self.cancelled.is_set():
- cancelled = self.cancelled.wait(interval)
- if not cancelled:
+ try:
+ if now:
self()
- self.cancelled.clear()
- self.running = False
+ while not self.cancelled.is_set():
+ cancelled = self.cancelled.wait(interval)
+ if not cancelled:
+ self()
+ finally:
+ self.cancelled.clear()
+ self.thread = None
+ self.running = False

0 comments on commit 1de418a

Please sign in to comment.
Something went wrong with that request. Please try again.