Permalink
Browse files

merge of runtime.

Merge branch 'master' of github.com:nimbusproject/dashi

Conflicts:
	scale_scripts/ping.yml
  • Loading branch information...
2 parents 5586c78 + 3ef0d17 commit 038c30b4af2ec6a6aa3c813201ce8a7c8f38f398 @buzztroll buzztroll committed Dec 20, 2011
Showing with 156 additions and 35 deletions.
  1. +52 −14 dashi/__init__.py
  2. +38 −2 dashi/tests/test_dashi.py
  3. +46 −4 dashi/tests/test_util.py
  4. +17 −13 dashi/util.py
  5. +2 −1 scale_scripts/pingpong.py
  6. +1 −1 scale_scripts/run_pingpong.sh
View
@@ -1,4 +1,5 @@
import socket
+import threading
import traceback
import uuid
import sys
@@ -34,10 +35,21 @@ def __init__(self, name, uri, exchange, durable=False, auto_delete=True, seriali
self._serializer = serializer
- def fire(self, name, operation, **kwargs):
+ @property
+ def name(self):
+ return self._name
+
+ def fire(self, name, operation, args=None, **kwargs):
"""Send a message without waiting for a reply
"""
- d = dict(op=operation, args=kwargs)
+
+ if args:
+ if kwargs:
+ raise TypeError("specify args dict or keyword arguments, not both")
+ else:
+ args = kwargs
+
+ d = dict(op=operation, args=args)
with producers[self._conn].acquire(block=True) as producer:
maybe_declare(self._exchange, producer.channel)
@@ -124,8 +136,13 @@ def handle(self, operation, operation_name=None):
def consume(self, count=None, timeout=None):
self._consumer.consume(count, timeout)
- def cancel(self):
- self._consumer.cancel()
+ def cancel(self, block=True):
+ if self._consumer:
+ self._consumer.cancel(block=block)
+
+ def disconnect(self):
+ if self._consumer:
+ self._consumer.disconnect()
class DashiConsumer(object):
@@ -138,6 +155,7 @@ def __init__(self, dashi, connection, name, exchange):
self._channel = None
self._ops = {}
self._cancelled = False
+ self._consumer_lock = threading.Lock()
self.connect()
@@ -154,16 +172,29 @@ def connect(self):
callbacks=[self._callback])
self._consumer.consume()
+ def disconnect(self):
+ self._consumer.cancel()
+
def consume(self, count=None, timeout=None):
- if count:
- i = 0
- while i < count and not self._cancelled:
- self._consume_one(timeout)
- i += 1
- else:
- while not self._cancelled:
- self._consume_one(timeout)
- self._cancelled = False
+
+ # hold a lock for the duration of the consuming. this prevents
+ # multiple consumers and allows cancel to detect when consuming
+ # has ended.
+ if not self._consumer_lock.acquire(False):
+ raise Exception("only one consumer thread may run concurrently")
+
+ try:
+ if count:
+ i = 0
+ while i < count and not self._cancelled:
+ self._consume_one(timeout)
+ i += 1
+ else:
+ while not self._cancelled:
+ self._consume_one(timeout)
+ finally:
+ self._consumer_lock.release()
+ self._cancelled = False
def _consume_one(self, timeout=None):
@@ -193,8 +224,12 @@ def _consume_one(self, timeout=None):
inner_timeout = timeout - elapsed
- def cancel(self):
+ def cancel(self, block=True):
self._cancelled = True
+ if block:
+ # acquire the lock and release it immediately
+ with self._consumer_lock:
+ pass
def _callback(self, body, message):
reply_to = None
@@ -217,6 +252,9 @@ def _callback(self, body, message):
try:
ret = op_fun(**args)
+ except TypeError, e:
+ log.exception("Type error with handler for %s:%s", self._name, op)
+ raise BadRequestError("Type error: %s" % str(e))
except Exception:
log.exception("Error in handler for %s:%s", self._name, op)
raise
View
@@ -3,14 +3,15 @@
from functools import partial
import itertools
import uuid
+import logging
from kombu.pools import connections
import dashi
import dashi.util
from dashi.tests.util import who_is_calling
-log = dashi.util.get_logger()
+log = logging.getLogger(__name__)
_NO_REPLY = object()
@@ -32,20 +33,31 @@ def __init__(self, **kwargs):
self.reply_with = {}
self.consumer_thread = None
+ self.condition = threading.Condition()
def handle(self, opname, reply_with=_NO_REPLY):
if reply_with is not _NO_REPLY:
self.reply_with[opname] = reply_with
self.conn.handle(partial(self._handler, opname), opname)
def _handler(self, opname, **kwargs):
- self.received.append((opname, kwargs))
+ with self.condition:
+ self.received.append((opname, kwargs))
+ self.condition.notifyAll()
+
if opname in self.reply_with:
reply_with = self.reply_with[opname]
if callable(reply_with):
return reply_with()
return reply_with
+ def wait(self, timeout=5):
+ with self.condition:
+ while not self.received:
+ self.condition.wait(timeout)
+ if not self.received:
+ raise Exception("timed out waiting for message")
+
def consume(self, count):
self.conn.consume(count=count, timeout=self.consume_timeout)
@@ -201,6 +213,30 @@ def test_cancel(self):
# this should hang forever if cancel doesn't work
receiver.join_consumer_thread()
+ def test_cancel_resume_cancel(self):
+ receiver = TestReceiver(uri=self.uri, exchange="x1")
+ receiver.handle("test", 1)
+ receiver.consume_in_thread()
+
+ conn = dashi.DashiConnection("s1", self.uri, "x1")
+ self.assertEqual(1, conn.call(receiver.name, "test"))
+
+ receiver.cancel()
+ receiver.join_consumer_thread()
+ receiver.clear()
+
+ # send message while receiver is cancelled
+ conn.fire(receiver.name, "test", hats=4)
+
+ # start up consumer again. message should arrive.
+ receiver.consume_in_thread()
+
+ receiver.wait()
+ self.assertEqual(receiver.received[-1], ("test", dict(hats=4)))
+
+ receiver.cancel()
+ receiver.join_consumer_thread()
+
class RabbitDashiConnectionTests(DashiConnectionTests):
"""The base dashi tests run on rabbit, plus some extras which are
View
@@ -8,16 +8,28 @@ class LoopingCallTests(unittest.TestCase):
def setUp(self):
self.calls = 0
self.passed = []
+ self.condition = threading.Condition()
self.loop = None
# tests can set this to make looper stop itself after a specified
# number of calls. self.loop must also be set.
self.max_calls = None
+ # tests can set this to make looper raise an exception
+ self.raise_this = None
+
# when looper kills itself, it will set this event
self.stopped = threading.Event()
+ def tearDown(self):
+ if self.loop:
+ # peek into loop and make sure thread is joined
+ self.loop.stop()
+ thread = self.loop.thread
+ if thread:
+ thread.join()
+
def assertPassed(self, index, *args, **kwargs):
passed_args, passed_kwargs = self.passed[index]
self.assertEqual(args, passed_args)
@@ -27,23 +39,33 @@ def assertLastPassed(self, *args, **kwargs):
self.assertPassed(-1, *args, **kwargs)
def looper(self, *args, **kwargs):
- self.calls += 1
- self.passed.append((args, kwargs))
+ with self.condition:
+ self.calls += 1
+ self.passed.append((args, kwargs))
+ self.condition.notifyAll()
if self.max_calls and self.calls >= self.max_calls:
self.loop.stop()
self.stopped.set()
+ if self.raise_this:
+ raise self.raise_this
+
def test_start_stop(self):
- loop = LoopingCall(self.looper, 1, hats=True)
+ self.loop = loop = LoopingCall(self.looper, 1, hats=True)
loop.start(1)
loop.stop()
+
+ with self.condition:
+ if not self.calls:
+ self.condition.wait(5)
+
self.assertEqual(self.calls, 1)
self.assertLastPassed(1, hats=True)
def test_start_stop_2(self):
- loop = LoopingCall(self.looper, 1, hats=True)
+ self.loop = loop = LoopingCall(self.looper, 1, hats=True)
loop.start(1, now=False)
loop.stop()
@@ -58,12 +80,32 @@ def test_called(self):
loop.start(0)
self.assertTrue(self.stopped.wait(5))
+ #peek into looping call and join on thread
+ thread = loop.thread
+ if thread:
+ thread.join()
+
self.assertFalse(loop.running)
self.assertEqual(self.calls, 3)
self.assertPassed(0, 1, 2, anarg=5)
self.assertPassed(1, 1, 2, anarg=5)
self.assertPassed(2, 1, 2, anarg=5)
+ def test_error_caught(self):
+ self.loop = LoopingCall(self.looper)
+ self.raise_this = Exception("too many sandwiches")
+
+ self.loop.start(0)
+
+ with self.condition:
+ while not self.calls >= 3:
+
+ self.condition.wait()
+
+ self.loop.stop()
+ self.assertGreaterEqual(self.calls, 3)
+
+
View
@@ -1,9 +1,6 @@
import logging
import threading
-def get_logger():
- return logging.getLogger('dashi')
-
class LoopingCall(object):
def __init__(self, fun, *args, **kwargs):
@@ -19,8 +16,12 @@ def __init__(self, fun, *args, **kwargs):
self.running = False
+ def __del__(self):
+ self.stop()
+
def start(self, interval, now=True):
assert self.thread is None
+ self.cancelled.clear()
self.running = True
self.thread = threading.Thread(target=self._looper,
@@ -29,22 +30,25 @@ def start(self, interval, now=True):
self.thread.start()
def stop(self):
- self.cancelled.set()
- self.thread = None
+ if self.thread:
+ self.cancelled.set()
def __call__(self):
try:
self.fun(*self.args, **self.kwargs)
except Exception:
- log = get_logger()
+ log = logging.getLogger(__name__)
log.exception("Error in looping call")
def _looper(self, interval, now):
- if now:
- self()
- while not self.cancelled.is_set():
- cancelled = self.cancelled.wait(interval)
- if not cancelled:
+ try:
+ if now:
self()
- self.cancelled.clear()
- self.running = False
+ while not self.cancelled.is_set():
+ cancelled = self.cancelled.wait(interval)
+ if not cancelled:
+ self()
+ finally:
+ self.cancelled.clear()
+ self.thread = None
+ self.running = False
@@ -47,9 +47,10 @@ def go(self):
self.timer.start()
self.start_time = datetime.datetime.now()
print "sending first ping"
+ self.dashi.fire(self.CFG.test.ponger_name, "ping")
while not self.done:
try:
- self.dashi.fire(self.CFG.test.ponger_name, "ping")
+ self.dashi.consume(count=1, timeout=2)
except socket.timeout, ex:
pass
print "sending final message"
@@ -46,7 +46,7 @@ do
ssh $pinger_host $py $r_pgm_file --test.type=ping $cmd_line_args $r_conf_file | tee $out_file
echo "pinger finished, waiting for ponger"
- wait $recv_pid
+ kill $recv_pid
echo "ponger finished"
kill $kill_pid1
kill $kill_pid2

0 comments on commit 038c30b

Please sign in to comment.