Permalink
Browse files

More ZooLock features and better tests

  • Loading branch information...
1 parent a28b01a commit af013384740e77d8d0cc847129ce9eed6af086c2 @labisso labisso committed Feb 2, 2012
Showing with 143 additions and 30 deletions.
  1. +2 −2 kazoo/__init__.py
  2. +21 −2 kazoo/recipe/lock.py
  3. +97 −20 kazoo/recipe/test/test_lock.py
  4. +1 −1 kazoo/sync/test/test_sync.py
  5. +22 −5 kazoo/test/__init__.py
View
@@ -12,7 +12,7 @@ def disable_zookeeper_log():
import zookeeper
zookeeper.set_log_stream(open('/dev/null'))
-if not "AZK_LOG_ENABLED" in os.environ:
+if not "KAZOO_LOG_ENABLED" in os.environ:
disable_zookeeper_log()
def patch_extras():
@@ -23,6 +23,6 @@ def patch_extras():
import threading
threading._sleep = sleep
-if "AZK_TEST_GEVENT_PATCH" in os.environ:
+if "KAZOO_TEST_GEVENT_PATCH" in os.environ:
from gevent import monkey; monkey.patch_all()
patch_extras()
View
@@ -2,18 +2,23 @@
import uuid
from kazoo.retry import ForceRetryError
+from zookeeper import NoNodeException
#noinspection PyArgumentList
class ZooLock(object):
_LOCK_NAME = '_lock_'
- def __init__(self, client, path):
+ def __init__(self, client, path, contender_name=None):
"""
@type client ZooKeeperClient
"""
self.client = client
self.path = path
+ # some data is written to the node. this can be queries via
+ # get_contenders() to see who is contending for the lock
+ self.data = str(contender_name or "")
+
self.condition = threading.Condition()
# props to Netflix Curator for this trick. It is possible for our
@@ -49,7 +54,7 @@ def _inner_acquire(self):
self.create_tried = True
if not node:
- node = self.client.create(self.create_path, "",
+ node = self.client.create(self.create_path, self.data,
ephemeral=True, sequence=True)
# strip off path to node
node = node[len(self.path)+1:]
@@ -123,6 +128,20 @@ def _inner_release(self):
return True
+ def get_contenders(self):
+ """Return an ordered list of the current contenders for the lock
+ """
+ children = self._get_sorted_children()
+
+ contenders = []
+ for child in children:
+ try:
+ data, stat = self.client.get(self.path + "/" + child)
+ contenders.append(data)
+ except NoNodeException:
+ pass
+ return contenders
+
def __enter__(self):
self.acquire()
@@ -4,7 +4,7 @@
import time
from kazoo.recipe.lock import ZooLock
-from kazoo.test import get_client_or_skip
+from kazoo.test import get_client_or_skip, until_timeout
class ZooLockTests(unittest.TestCase):
@@ -14,7 +14,8 @@ def setUp(self):
self.lockpath = "/" + uuid.uuid4().hex
self._c.create(self.lockpath, "")
- self.active = 0
+ self.condition = threading.Condition()
+ self.active_thread = None
def tearDown(self):
if self.lockpath:
@@ -23,37 +24,113 @@ def tearDown(self):
except Exception:
pass
+ def test_lock_one(self):
+ c = get_client_or_skip()
+ c.connect()
+
+ contender_name = uuid.uuid4().hex
+ lock = ZooLock(c, self.lockpath, contender_name)
+
+ event = threading.Event()
+
+ thread = threading.Thread(target=self._thread_lock_acquire_til_event,
+ args=(contender_name, lock, event))
+ thread.start()
+
+ anotherlock = ZooLock(c, self.lockpath, contender_name)
+ contenders = None
+ for _ in until_timeout(5):
+ contenders = anotherlock.get_contenders()
+ if contenders:
+ break
+ time.sleep(0)
+
+ self.assertEqual(contenders, [contender_name])
+
+ with self.condition:
+ while self.active_thread != contender_name:
+ self.condition.wait()
+
+ # release the lock
+ event.set()
+
+ with self.condition:
+ while self.active_thread:
+ self.condition.wait()
+
+
def test_lock(self):
- clients = []
- locks = []
+ threads = []
+ names = ["contender"+str(i) for i in range(5)]
+
+ contender_bits = {}
- for _ in range(5):
+ for name in names:
c = get_client_or_skip()
c.connect()
- l = ZooLock(c, self.lockpath)
+ e = threading.Event()
- clients.append(c)
- locks.append(l)
+ l = ZooLock(c, self.lockpath, name)
+ t = threading.Thread(target=self._thread_lock_acquire_til_event,
+ args=(name, l, e))
+ contender_bits[name] = (t, e)
+ threads.append(t)
- # these will be greenlets in a monkey patched test env.
- threads = [threading.Thread(target=self._thread_lock_acquire,
- args=(lock,)) for lock in locks]
+ # acquire the lock ourselves first to make the others line up
+ lock = ZooLock(self._c, self.lockpath, "test")
+ lock.acquire()
for t in threads:
t.start()
- for t in threads:
- t.join()
+ contenders = None
+ # wait for everyone to line up on the lock
+ for _ in until_timeout(5):
+ contenders = lock.get_contenders()
+ if len(contenders) == 6:
+ break
+
+ self.assertEqual(contenders[0], "test")
+ contenders = contenders[1:]
+ remaining = list(contenders)
+
+ # release the lock and contenders should claim it in order
+ lock.release()
+
+ for contender in contenders:
+ thread, event = contender_bits[contender]
+
+ with self.condition:
+ while not self.active_thread:
+ self.condition.wait()
+ self.assertEqual(self.active_thread, contender)
+
+ self.assertEqual(lock.get_contenders(), remaining)
+ remaining = remaining[1:]
- self.assertEqual(0, self.active)
+ event.set()
- def _thread_lock_acquire(self, lock):
+ with self.condition:
+ while self.active_thread:
+ self.condition.wait()
+ thread.join()
+
+
+ def _thread_lock_acquire_til_event(self, name, lock, event):
with lock:
- self.active += 1
- self.assertEqual(self.active, 1)
- print "got lock"
- time.sleep(0)
- self.active -= 1
+ #print "%s enter lock" % name
+ with self.condition:
+ self.assertIsNone(self.active_thread)
+ self.active_thread = name
+ self.condition.notify_all()
+
+ event.wait()
+
+ with self.condition:
+ self.assertEqual(self.active_thread, name)
+ self.active_thread = None
+ self.condition.notify_all()
+ #print "%s exit lock" % name
@@ -70,7 +70,7 @@ def fun(i):
realthread.start_new_thread(thread_dispatch_callbacks,
(self.sync, fun, callbacks))
- done.wait()
+ done.wait(10)
self.assertEqual(results, range(callbacks))
def thread_set_async_result(async_result, value=None, exception=None):
View
@@ -1,19 +1,36 @@
import os
import unittest
+import time
from kazoo.client import ZooKeeperClient
# if this env variable is set, ZK client integration tests are run
# against the specified host list
-ENV_AZK_TEST_HOSTS = "AZK_TEST_HOSTS"
+ENV_TEST_HOSTS = "KAZOO_TEST_HOSTS"
def get_hosts_or_skip():
- if ENV_AZK_TEST_HOSTS in os.environ:
- return os.environ[ENV_AZK_TEST_HOSTS]
+ if ENV_TEST_HOSTS in os.environ:
+ return os.environ[ENV_TEST_HOSTS]
raise unittest.SkipTest("Skipping ZooKeeper test. To run, set "+
"%s env to a host list. (ex: localhost:2181)" %
- ENV_AZK_TEST_HOSTS)
+ ENV_TEST_HOSTS)
def get_client_or_skip(**kwargs):
hosts = get_hosts_or_skip()
- return ZooKeeperClient(hosts, **kwargs)
+ return ZooKeeperClient(hosts, **kwargs)
+
+def until_timeout(timeout, value=None):
+ """Returns an iterator that repeats until a timeout is reached
+
+ timeout is in seconds
+ """
+
+ start = time.time()
+
+ while True:
+ if time.time() - start >= timeout:
+ raise Exception("timed out before success!")
+ yield value
+
+
+

0 comments on commit af01338

Please sign in to comment.