Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

Factor out thread-rendezvous logic from test_threads.py, and use it i…

…n test_thread_util.py
  • Loading branch information...
commit 024ed62313d0a92905b18a9d34dacc732ac89ee0 1 parent e31a07a
@ajdavis ajdavis authored
Showing with 194 additions and 108 deletions.
  1. +76 −36 test/test_thread_util.py
  2. +20 −72 test/test_threads.py
  3. +98 −0 test/utils.py
View
112 test/test_thread_util.py
@@ -29,68 +29,108 @@
if thread_util.have_greenlet:
import greenlet
-from test.utils import looplet
+from test.utils import looplet, RendezvousThread
class TestIdent(unittest.TestCase):
+ """Ensure thread_util.Ident works for threads and greenlets. This has
+ gotten intricate from refactoring: we have classes, Watched and Unwatched,
+ that implement the logic for the two child threads / greenlets. For the
+ greenlet case it's easy to ensure the two children are alive at once, so
+ we run the Watched and Unwatched logic directly. For the thread case we
+ mix in the RendezvousThread class so we're sure both children are alive
+ when they call Ident.get().
+
+ 1. Store main thread's / greenlet's id
+ 2. Start 2 child threads / greenlets
+ 3. Store their values for Ident.get()
+ 4. Children reach rendezvous point
+ 5. Children call Ident.watch()
+ 6. One of the children calls Ident.unwatch()
+ 7. Children terminate
+ 8. Assert that children got different ids from each other and from main,
+ and assert watched child's callback was executed, and that unwatched
+ child's callback was not
+ """
def _test_ident(self, use_greenlets):
ident = thread_util.create_ident(use_greenlets)
ids = set([ident.get()])
unwatched_id = []
- done = set([ident.get()]) # Start with main thread's id
+ done = set([ident.get()]) # Start with main thread's / greenlet's id
died = set()
- def watched_thread():
- my_id = ident.get()
- time.sleep(.1) # Ensure other thread starts so we don't recycle ids
- ids.add(my_id)
- self.assertFalse(ident.watching())
+ class Watched(object):
+ def __init__(self, ident):
+ self._ident = ident
- def on_died(ref):
- died.add(my_id)
+ def before_rendezvous(self):
+ self.my_id = self._ident.get()
+ ids.add(self.my_id)
- ident.watch(on_died)
- self.assertTrue(ident.watching())
- done.add(my_id)
+ def after_rendezvous(self):
+ assert not self._ident.watching()
+ self._ident.watch(lambda ref: died.add(self.my_id))
+ assert self._ident.watching()
+ done.add(self.my_id)
- def unwatched_thread():
- my_id = ident.get()
- time.sleep(.1) # Ensure other thread starts so we don't recycle ids
- unwatched_id.append(my_id)
- ids.add(my_id)
- self.assertFalse(ident.watching())
+ class Unwatched(Watched):
+ def before_rendezvous(self):
+ Watched.before_rendezvous(self)
+ unwatched_id.append(self.my_id)
- def on_died(ref):
- died.add(my_id)
-
- ident.watch(on_died)
- self.assertTrue(ident.watching())
- ident.unwatch()
- self.assertFalse(ident.watching())
- done.add(my_id)
+ def after_rendezvous(self):
+ Watched.after_rendezvous(self)
+ self._ident.unwatch()
+ assert not self._ident.watching()
if use_greenlets:
- t_watched = greenlet.greenlet(watched_thread)
- t_unwatched = greenlet.greenlet(unwatched_thread)
+ class WatchedGreenlet(Watched):
+ def run(self):
+ self.before_rendezvous()
+ self.after_rendezvous()
+
+ class UnwatchedGreenlet(Unwatched):
+ def run(self):
+ self.before_rendezvous()
+ self.after_rendezvous()
+
+ t_watched = greenlet.greenlet(WatchedGreenlet(ident).run)
+ t_unwatched = greenlet.greenlet(UnwatchedGreenlet(ident).run)
looplet([t_watched, t_unwatched])
else:
- t_watched = threading.Thread(target=watched_thread)
- t_watched.setDaemon(True)
+ class WatchedThread(Watched, RendezvousThread):
+ def __init__(self, ident, state):
+ Watched.__init__(self, ident)
+ RendezvousThread.__init__(self, state)
+
+ class UnwatchedThread(Unwatched, RendezvousThread):
+ def __init__(self, ident, state):
+ Unwatched.__init__(self, ident)
+ RendezvousThread.__init__(self, state)
+
+ state = RendezvousThread.create_shared_state(2)
+ t_watched = WatchedThread(ident, state)
t_watched.start()
- t_unwatched = threading.Thread(target=unwatched_thread)
- t_unwatched.setDaemon(True)
+ t_unwatched = UnwatchedThread(ident, state)
t_unwatched.start()
+ RendezvousThread.wait_for_rendezvous(state)
+ RendezvousThread.resume_after_rendezvous(state)
+
t_watched.join()
t_unwatched.join()
+ self.assertTrue(t_watched.passed)
+ self.assertTrue(t_unwatched.passed)
+
# Remove references, let weakref callbacks run
del t_watched
del t_unwatched
# Accessing the thread-local triggers cleanup in Python <= 2.6
+ # http://bugs.python.org/issue1868
ident.get()
self.assertEqual(3, len(ids))
self.assertEqual(3, len(done))
@@ -127,11 +167,11 @@ def _test_counter(self, use_greenlets):
done = set()
def f(n):
- for i in range(n):
+ for i in xrange(n):
self.assertEqual(i, counter.get())
self.assertEqual(i + 1, counter.inc())
- for i in range(n, 0, -1):
+ for i in xrange(n, 0, -1):
self.assertEqual(i, counter.get())
self.assertEqual(i - 1, counter.dec())
@@ -147,11 +187,11 @@ def f(n):
if use_greenlets:
greenlets = [
- greenlet.greenlet(partial(f, i)) for i in range(10)]
+ greenlet.greenlet(partial(f, i)) for i in xrange(10)]
looplet(greenlets)
else:
threads = [
- threading.Thread(target=partial(f, i)) for i in range(10)]
+ threading.Thread(target=partial(f, i)) for i in xrange(10)]
for t in threads:
t.start()
for t in threads:
View
92 test/test_threads.py
@@ -20,7 +20,7 @@
from nose.plugins.skip import SkipTest
-from test.utils import server_started_with_auth, joinall
+from test.utils import server_started_with_auth, joinall, RendezvousThread
from test.test_connection import get_connection
from pymongo.connection import Connection
from pymongo.replica_set_connection import ReplicaSetConnection
@@ -137,86 +137,37 @@ def run(self):
pass
-class FindPauseFind(threading.Thread):
+class FindPauseFind(RendezvousThread):
"""See test_server_disconnect() for details"""
- @classmethod
- def shared_state(cls, nthreads):
- class SharedState(object):
- pass
-
- state = SharedState()
-
- # Number of threads total
- state.nthreads = nthreads
-
- # Number of threads that have arrived at rendezvous point
- state.arrived_threads = 0
- state.arrived_threads_lock = threading.Lock()
-
- # set when all threads reach rendezvous
- state.ev_arrived = threading.Event()
-
- # set from outside FindPauseFind to let threads resume after
- # rendezvous
- state.ev_resume = threading.Event()
- return state
-
def __init__(self, collection, state):
- """Params: A collection, an event to signal when all threads have
- done the first find(), an event to signal when threads should resume,
- and the total number of threads
+ """Params:
+ `collection`: A collection for testing
+ `state`: A shared state object from RendezvousThread.shared_state()
"""
- super(FindPauseFind, self).__init__()
+ super(FindPauseFind, self).__init__(state)
self.collection = collection
- self.state = state
- self.passed = False
-
- # If this thread fails to terminate, don't hang the whole program
- self.setDaemon(True)
-
- def rendezvous(self):
- # pause until all threads arrive here
- s = self.state
- s.arrived_threads_lock.acquire()
- s.arrived_threads += 1
- if s.arrived_threads == s.nthreads:
- s.arrived_threads_lock.release()
- s.ev_arrived.set()
- else:
- s.arrived_threads_lock.release()
- s.ev_arrived.wait()
- def run(self):
- try:
- # acquire a socket
- list(self.collection.find())
-
- pool = get_pool(self.collection.database.connection)
- socket_info = pool._get_request_state()
- assert isinstance(socket_info, SocketInfo)
- self.request_sock = socket_info.sock
- assert not _closed(self.request_sock)
-
- # Dereference socket_info so it can potentially return to the pool
- del socket_info
- finally:
- self.rendezvous()
+ def before_rendezvous(self):
+ # acquire a socket
+ list(self.collection.find())
- # all threads have passed the rendezvous, wait for
- # test_server_disconnect() to disconnect the connection
- self.state.ev_resume.wait()
+ self.pool = get_pool(self.collection.database.connection)
+ socket_info = self.pool._get_request_state()
+ assert isinstance(socket_info, SocketInfo)
+ self.request_sock = socket_info.sock
+ assert not _closed(self.request_sock)
+ def after_rendezvous(self):
# test_server_disconnect() has closed this socket, but that's ok
# because it's not our request socket anymore
assert _closed(self.request_sock)
# if disconnect() properly closed all threads' request sockets, then
# this won't raise AutoReconnect because it will acquire a new socket
- assert self.request_sock == pool._get_request_state().sock
+ assert self.request_sock == self.pool._get_request_state().sock
list(self.collection.find())
assert self.collection.database.connection.in_request()
- assert self.request_sock != pool._get_request_state().sock
- self.passed = True
+ assert self.request_sock != self.pool._get_request_state().sock
class BaseTestThreads(object):
@@ -320,7 +271,7 @@ def test_server_disconnect(self):
assert isinstance(socket_info, SocketInfo)
request_sock = socket_info.sock
- state = FindPauseFind.shared_state(nthreads=40)
+ state = FindPauseFind.create_shared_state(nthreads=40)
threads = [
FindPauseFind(collection, state)
@@ -332,12 +283,9 @@ def test_server_disconnect(self):
t.start()
# Wait for the threads to reach the rendezvous
- state.ev_arrived.wait(10)
- self.assertTrue(state.ev_arrived.isSet(), "Thread timeout")
+ FindPauseFind.wait_for_rendezvous(state)
try:
- self.assertEqual(state.nthreads, state.arrived_threads)
-
# Simulate an event that closes all sockets, e.g. primary stepdown
for t in threads:
t.request_sock.close()
@@ -355,7 +303,7 @@ def test_server_disconnect(self):
finally:
# Let threads do a second find()
- state.ev_resume.set()
+ FindPauseFind.resume_after_rendezvous(state)
joinall(threads)
View
98 test/utils.py
@@ -15,6 +15,8 @@
"""Utilities for testing pymongo
"""
+import threading
+
from pymongo.errors import AutoReconnect
from pymongo.pool import NO_REQUEST, NO_SOCKET_YET, SocketInfo
@@ -81,6 +83,102 @@ def looplet(greenlets):
if done:
return
+class RendezvousThread(threading.Thread):
+ """A thread that starts and pauses at a rendezvous point before resuming.
+ To be used in tests that must ensure that N threads are all alive
+ simultaneously, regardless of thread-scheduling's vagaries.
+
+ 1. Write a subclass of RendezvousThread and override before_rendezvous
+ and / or after_rendezvous.
+ 2. Create a state with RendezvousThread.shared_state(N)
+ 3. Start N of your subclassed RendezvousThreads, passing the state to each
+ one's __init__
+ 4. In the main thread, call RendezvousThread.wait_for_rendezvous
+ 5. Test whatever you need to test while threads are paused at rendezvous
+ point
+ 6. In main thread, call RendezvousThread.resume_after_rendezvous
+ 7. Join all threads from main thread
+ 8. Assert that all threads' "passed" attribute is True
+ 9. Test post-conditions
+ """
+ class RendezvousState(object):
+ def __init__(self, nthreads):
+ # Number of threads total
+ self.nthreads = nthreads
+
+ # Number of threads that have arrived at rendezvous point
+ self.arrived_threads = 0
+ self.arrived_threads_lock = threading.Lock()
+
+ # Set when all threads reach rendezvous
+ self.ev_arrived = threading.Event()
+
+ # Set by resume_after_rendezvous() so threads can continue.
+ self.ev_resume = threading.Event()
+
+
+ @classmethod
+ def create_shared_state(cls, nthreads):
+ return RendezvousThread.RendezvousState(nthreads)
+
+ def before_rendezvous(self):
+ """Overridable: Do this before the rendezvous"""
+ pass
+
+ def after_rendezvous(self):
+ """Overridable: Do this after the rendezvous. If it throws no exception,
+ `passed` is set to True
+ """
+ pass
+
+ @classmethod
+ def wait_for_rendezvous(cls, state):
+ """Wait for all threads to reach rendezvous and pause there"""
+ state.ev_arrived.wait(10)
+ assert state.ev_arrived.isSet(), "Thread timeout"
+ assert state.nthreads == state.arrived_threads
+
+ @classmethod
+ def resume_after_rendezvous(cls, state):
+ """Tell all the paused threads to continue"""
+ state.ev_resume.set()
+
+ def __init__(self, state):
+ """Params:
+ `state`: A shared state object from RendezvousThread.shared_state()
+ """
+ super(RendezvousThread, self).__init__()
+ self.state = state
+ self.passed = False
+
+ # If this thread fails to terminate, don't hang the whole program
+ self.setDaemon(True)
+
+ def _rendezvous(self):
+ """Pause until all threads arrive here"""
+ s = self.state
+ s.arrived_threads_lock.acquire()
+ s.arrived_threads += 1
+ if s.arrived_threads == s.nthreads:
+ s.arrived_threads_lock.release()
+ s.ev_arrived.set()
+ else:
+ s.arrived_threads_lock.release()
+ s.ev_arrived.wait()
+
+ def run(self):
+ try:
+ self.before_rendezvous()
+ finally:
+ self._rendezvous()
+
+ # all threads have passed the rendezvous, wait for
+ # resume_after_rendezvous()
+ self.state.ev_resume.wait()
+
+ self.after_rendezvous()
+ self.passed = True
+
def read_from_which_host(
rsc,
mode,
Please sign in to comment.
Something went wrong with that request. Please try again.