Permalink
Browse files

Factor out thread-rendezvous logic from test_threads.py, and use it i…

…n test_thread_util.py
  • Loading branch information...
ajdavis committed Dec 27, 2012
1 parent e31a07a commit 024ed62313d0a92905b18a9d34dacc732ac89ee0
Showing with 194 additions and 108 deletions.
  1. +76 −36 test/test_thread_util.py
  2. +20 −72 test/test_threads.py
  3. +98 −0 test/utils.py
View
@@ -29,68 +29,108 @@
if thread_util.have_greenlet:
import greenlet
-from test.utils import looplet
+from test.utils import looplet, RendezvousThread
class TestIdent(unittest.TestCase):
+ """Ensure thread_util.Ident works for threads and greenlets. This has
+ gotten intricate from refactoring: we have classes, Watched and Unwatched,
+ that implement the logic for the two child threads / greenlets. For the
+ greenlet case it's easy to ensure the two children are alive at once, so
+ we run the Watched and Unwatched logic directly. For the thread case we
+ mix in the RendezvousThread class so we're sure both children are alive
+ when they call Ident.get().
+
+ 1. Store main thread's / greenlet's id
+ 2. Start 2 child threads / greenlets
+ 3. Store their values for Ident.get()
+ 4. Children reach rendezvous point
+ 5. Children call Ident.watch()
+ 6. One of the children calls Ident.unwatch()
+ 7. Children terminate
+ 8. Assert that children got different ids from each other and from main,
+ and assert watched child's callback was executed, and that unwatched
+ child's callback was not
+ """
def _test_ident(self, use_greenlets):
ident = thread_util.create_ident(use_greenlets)
ids = set([ident.get()])
unwatched_id = []
- done = set([ident.get()]) # Start with main thread's id
+ done = set([ident.get()]) # Start with main thread's / greenlet's id
died = set()
- def watched_thread():
- my_id = ident.get()
- time.sleep(.1) # Ensure other thread starts so we don't recycle ids
- ids.add(my_id)
- self.assertFalse(ident.watching())
+ class Watched(object):
+ def __init__(self, ident):
+ self._ident = ident
- def on_died(ref):
- died.add(my_id)
+ def before_rendezvous(self):
+ self.my_id = self._ident.get()
+ ids.add(self.my_id)
- ident.watch(on_died)
- self.assertTrue(ident.watching())
- done.add(my_id)
+ def after_rendezvous(self):
+ assert not self._ident.watching()
+ self._ident.watch(lambda ref: died.add(self.my_id))
+ assert self._ident.watching()
+ done.add(self.my_id)
- def unwatched_thread():
- my_id = ident.get()
- time.sleep(.1) # Ensure other thread starts so we don't recycle ids
- unwatched_id.append(my_id)
- ids.add(my_id)
- self.assertFalse(ident.watching())
+ class Unwatched(Watched):
+ def before_rendezvous(self):
+ Watched.before_rendezvous(self)
+ unwatched_id.append(self.my_id)
- def on_died(ref):
- died.add(my_id)
-
- ident.watch(on_died)
- self.assertTrue(ident.watching())
- ident.unwatch()
- self.assertFalse(ident.watching())
- done.add(my_id)
+ def after_rendezvous(self):
+ Watched.after_rendezvous(self)
+ self._ident.unwatch()
+ assert not self._ident.watching()
if use_greenlets:
- t_watched = greenlet.greenlet(watched_thread)
- t_unwatched = greenlet.greenlet(unwatched_thread)
+ class WatchedGreenlet(Watched):
+ def run(self):
+ self.before_rendezvous()
+ self.after_rendezvous()
+
+ class UnwatchedGreenlet(Unwatched):
+ def run(self):
+ self.before_rendezvous()
+ self.after_rendezvous()
+
+ t_watched = greenlet.greenlet(WatchedGreenlet(ident).run)
+ t_unwatched = greenlet.greenlet(UnwatchedGreenlet(ident).run)
looplet([t_watched, t_unwatched])
else:
- t_watched = threading.Thread(target=watched_thread)
- t_watched.setDaemon(True)
+ class WatchedThread(Watched, RendezvousThread):
+ def __init__(self, ident, state):
+ Watched.__init__(self, ident)
+ RendezvousThread.__init__(self, state)
+
+ class UnwatchedThread(Unwatched, RendezvousThread):
+ def __init__(self, ident, state):
+ Unwatched.__init__(self, ident)
+ RendezvousThread.__init__(self, state)
+
+ state = RendezvousThread.create_shared_state(2)
+ t_watched = WatchedThread(ident, state)
t_watched.start()
- t_unwatched = threading.Thread(target=unwatched_thread)
- t_unwatched.setDaemon(True)
+ t_unwatched = UnwatchedThread(ident, state)
t_unwatched.start()
+ RendezvousThread.wait_for_rendezvous(state)
+ RendezvousThread.resume_after_rendezvous(state)
+
t_watched.join()
t_unwatched.join()
+ self.assertTrue(t_watched.passed)
+ self.assertTrue(t_unwatched.passed)
+
# Remove references, let weakref callbacks run
del t_watched
del t_unwatched
# Accessing the thread-local triggers cleanup in Python <= 2.6
+ # http://bugs.python.org/issue1868
ident.get()
self.assertEqual(3, len(ids))
self.assertEqual(3, len(done))
@@ -127,11 +167,11 @@ def _test_counter(self, use_greenlets):
done = set()
def f(n):
- for i in range(n):
+ for i in xrange(n):
self.assertEqual(i, counter.get())
self.assertEqual(i + 1, counter.inc())
- for i in range(n, 0, -1):
+ for i in xrange(n, 0, -1):
self.assertEqual(i, counter.get())
self.assertEqual(i - 1, counter.dec())
@@ -147,11 +187,11 @@ def f(n):
if use_greenlets:
greenlets = [
- greenlet.greenlet(partial(f, i)) for i in range(10)]
+ greenlet.greenlet(partial(f, i)) for i in xrange(10)]
looplet(greenlets)
else:
threads = [
- threading.Thread(target=partial(f, i)) for i in range(10)]
+ threading.Thread(target=partial(f, i)) for i in xrange(10)]
for t in threads:
t.start()
for t in threads:
View
@@ -20,7 +20,7 @@
from nose.plugins.skip import SkipTest
-from test.utils import server_started_with_auth, joinall
+from test.utils import server_started_with_auth, joinall, RendezvousThread
from test.test_connection import get_connection
from pymongo.connection import Connection
from pymongo.replica_set_connection import ReplicaSetConnection
@@ -137,86 +137,37 @@ def run(self):
pass
-class FindPauseFind(threading.Thread):
+class FindPauseFind(RendezvousThread):
"""See test_server_disconnect() for details"""
- @classmethod
- def shared_state(cls, nthreads):
- class SharedState(object):
- pass
-
- state = SharedState()
-
- # Number of threads total
- state.nthreads = nthreads
-
- # Number of threads that have arrived at rendezvous point
- state.arrived_threads = 0
- state.arrived_threads_lock = threading.Lock()
-
- # set when all threads reach rendezvous
- state.ev_arrived = threading.Event()
-
- # set from outside FindPauseFind to let threads resume after
- # rendezvous
- state.ev_resume = threading.Event()
- return state
-
def __init__(self, collection, state):
- """Params: A collection, an event to signal when all threads have
- done the first find(), an event to signal when threads should resume,
- and the total number of threads
+ """Params:
+ `collection`: A collection for testing
+ `state`: A shared state object from RendezvousThread.shared_state()
"""
- super(FindPauseFind, self).__init__()
+ super(FindPauseFind, self).__init__(state)
self.collection = collection
- self.state = state
- self.passed = False
-
- # If this thread fails to terminate, don't hang the whole program
- self.setDaemon(True)
-
- def rendezvous(self):
- # pause until all threads arrive here
- s = self.state
- s.arrived_threads_lock.acquire()
- s.arrived_threads += 1
- if s.arrived_threads == s.nthreads:
- s.arrived_threads_lock.release()
- s.ev_arrived.set()
- else:
- s.arrived_threads_lock.release()
- s.ev_arrived.wait()
- def run(self):
- try:
- # acquire a socket
- list(self.collection.find())
-
- pool = get_pool(self.collection.database.connection)
- socket_info = pool._get_request_state()
- assert isinstance(socket_info, SocketInfo)
- self.request_sock = socket_info.sock
- assert not _closed(self.request_sock)
-
- # Dereference socket_info so it can potentially return to the pool
- del socket_info
- finally:
- self.rendezvous()
+ def before_rendezvous(self):
+ # acquire a socket
+ list(self.collection.find())
- # all threads have passed the rendezvous, wait for
- # test_server_disconnect() to disconnect the connection
- self.state.ev_resume.wait()
+ self.pool = get_pool(self.collection.database.connection)
+ socket_info = self.pool._get_request_state()
+ assert isinstance(socket_info, SocketInfo)
+ self.request_sock = socket_info.sock
+ assert not _closed(self.request_sock)
+ def after_rendezvous(self):
# test_server_disconnect() has closed this socket, but that's ok
# because it's not our request socket anymore
assert _closed(self.request_sock)
# if disconnect() properly closed all threads' request sockets, then
# this won't raise AutoReconnect because it will acquire a new socket
- assert self.request_sock == pool._get_request_state().sock
+ assert self.request_sock == self.pool._get_request_state().sock
list(self.collection.find())
assert self.collection.database.connection.in_request()
- assert self.request_sock != pool._get_request_state().sock
- self.passed = True
+ assert self.request_sock != self.pool._get_request_state().sock
class BaseTestThreads(object):
@@ -320,7 +271,7 @@ def test_server_disconnect(self):
assert isinstance(socket_info, SocketInfo)
request_sock = socket_info.sock
- state = FindPauseFind.shared_state(nthreads=40)
+ state = FindPauseFind.create_shared_state(nthreads=40)
threads = [
FindPauseFind(collection, state)
@@ -332,12 +283,9 @@ def test_server_disconnect(self):
t.start()
# Wait for the threads to reach the rendezvous
- state.ev_arrived.wait(10)
- self.assertTrue(state.ev_arrived.isSet(), "Thread timeout")
+ FindPauseFind.wait_for_rendezvous(state)
try:
- self.assertEqual(state.nthreads, state.arrived_threads)
-
# Simulate an event that closes all sockets, e.g. primary stepdown
for t in threads:
t.request_sock.close()
@@ -355,7 +303,7 @@ def test_server_disconnect(self):
finally:
# Let threads do a second find()
- state.ev_resume.set()
+ FindPauseFind.resume_after_rendezvous(state)
joinall(threads)
Oops, something went wrong.

0 comments on commit 024ed62

Please sign in to comment.