Permalink
Browse files

Use gevent+zeromq to handle jobs

  • Loading branch information...
dcramer committed May 15, 2012
1 parent 6aeff1f commit 50c563aeafc617c6d386560a05931925f635bd1b
View
@@ -5,7 +5,7 @@
setup(
name="taskmaster",
license='Apache License 2.0',
- version="0.3.1",
+ version="0.4.0",
description="",
author="David Cramer",
author_email="dcramer@gmail.com",
@@ -21,6 +21,9 @@
},
install_requires=[
'progressbar',
+ 'gevent',
+ 'gevent_zeromq',
+ # 'pyzmq-static',
],
tests_require=[
'unittest2',
@@ -6,35 +6,48 @@
:license: Apache License 2.0, see LICENSE for more details.
"""
-from multiprocessing.managers import BaseManager
-from threading import Thread
+import cPickle as pickle
+from gevent_zeromq import zmq
from taskmaster.controller import Controller
-import Queue
+from gevent.queue import Queue, Empty
-class QueueManager(BaseManager):
- pass
-
-
-class QueueServer(Thread):
- def __init__(self, host, port, size=None, authkey=None):
- Thread.__init__(self)
+class Server(object):
+ def __init__(self, host, port, size=None):
self.daemon = True
self.started = False
- self.queue = Queue.Queue(maxsize=size)
+ self.queue = Queue(maxsize=size)
+ self.address = 'tcp://%s:%s' % (host, port)
- QueueManager.register('get_queue', callable=lambda: self.queue)
+ def start(self):
+ self.started = True
+ self.context = context = zmq.Context(1)
- self.manager = QueueManager(address=(host, int(port)), authkey=authkey)
+ self.server = server = context.socket(zmq.REP)
+ server.bind(self.address)
- def run(self):
- self.started = True
- server = self.manager.get_server()
- print "Taskmaster server running on %r" % ':'.join(map(str, server.address))
- server.serve_forever()
+ print "Taskmaster server running on %r" % self.address
+
+ while self.started:
+ request = server.recv()
+ if request == 'GET':
+ try:
+ job = self.queue.get_nowait()
+ except Empty:
+ server.send('WAIT')
+ continue
+
+ server.send('OK %s' % (pickle.dumps(job),))
+ elif request == 'DONE':
+ self.queue.task_done()
+ server.send('OK')
+ else:
+ server.send('ERROR Unrecognized command')
+
+ self.shutdown()
def put_job(self, job):
- self.queue.put(job)
+ self.queue.put_nowait(job)
def first_job(self):
return self.queue.queue[0]
@@ -46,16 +59,17 @@ def is_alive(self):
return self.started
def shutdown(self):
- # TODO:
- # if self.started:
- # self.manager.shutdown()
+ if not self.started:
+ return
+ self.server.close()
+ self.context.term()
self.started = False
-def run(target, reset=False, size=10000, host='0.0.0.0:3050', key='taskmaster'):
+def run(target, reset=False, size=10000, host='0.0.0.0:3050'):
host, port = host.split(':')
- server = QueueServer(host, int(port), size=size, authkey=key)
+ server = Server(host, int(port), size=size)
controller = Controller(server, target)
if reset:
@@ -69,7 +83,6 @@ def main():
parser = optparse.OptionParser()
parser.add_option("--host", dest="host", default='0.0.0.0:3050')
parser.add_option("--size", dest="size", default='10000', type=int)
- parser.add_option("--key", dest="key", default='taskmaster')
parser.add_option("--reset", dest="reset", default=False, action='store_true')
(options, args) = parser.parse_args()
if len(args) != 1:
@@ -6,42 +6,25 @@
:license: Apache License 2.0, see LICENSE for more details.
"""
-from multiprocessing.managers import BaseManager
-from taskmaster.util import import_target
-from taskmaster.workers import ThreadPool
-import time
-
-class QueueManager(BaseManager):
- pass
-
-
-def run(target, host='0.0.0.0:3050', key='taskmaster', threads=1):
- QueueManager.register('get_queue')
+def run(target, host='0.0.0.0:3050', progressbar=True):
+ from taskmaster.consumer import Consumer
+ from taskmaster.util import import_target
host, port = host.split(':')
- m = QueueManager(address=(host, int(port)), authkey=key)
- m.connect()
- queue = m.get_queue()
-
target = import_target(target, 'handle_job')
- pool = ThreadPool(queue, target, size=threads)
- while pool.is_alive() and not queue.empty():
- time.sleep(0)
-
- pool.join()
+ client = Consumer(host, port, progressbar=progressbar)
+ client.start(target)
def main():
import optparse
import sys
parser = optparse.OptionParser()
parser.add_option("--host", dest="host", default='0.0.0.0:3050')
- parser.add_option("--key", dest="key", default='taskmaster')
- parser.add_option("--threads", dest="threads", default=1, type=int)
- # parser.add_option("--procs", dest="procs", default=1, type=int)
+ parser.add_option("--progress", dest="progressbar", action="store_true", default=False)
(options, args) = parser.parse_args()
if len(args) != 1:
print 'Usage: tm-slave <callback>'
@@ -27,14 +27,11 @@ def main():
import sys
parser = optparse.OptionParser()
parser.add_option("--host", dest="host", default='0.0.0.0:3050')
- parser.add_option("--key", dest="key", default='taskmaster')
- parser.add_option("--threads", dest="threads", default=1, type=int)
- # parser.add_option("--procs", dest="procs", default=1, type=int)
(options, args) = parser.parse_args()
if len(args) != 2:
print 'Usage: tm-spawn <callback> <processes>'
sys.exit(1)
- sys.exit(run(args[0], procs=int(args[1]), **options.__dict__))
+ sys.exit(run(args[0], procs=int(args[1]), progressbar=False, **options.__dict__))
if __name__ == '__main__':
main()
View
@@ -0,0 +1,122 @@
+"""
+taskmaster.consumer
+~~~~~~~~~~~~~~~~~~~
+
+:copyright: (c) 2010 DISQUS.
+:license: Apache License 2.0, see LICENSE for more details.
+"""
+
+import cPickle as pickle
+import gevent
+from gevent_zeromq import zmq
+from gevent.queue import Queue, Empty
+
+
+class Worker(object):
+ def __init__(self, client, target):
+ self.client = client
+ self.target = target
+
+ def run(self):
+ self.running = True
+ while self.running:
+ job_id, job = self.client.get_job()
+
+ try:
+ self.target(job)
+ except KeyboardInterrupt:
+ return
+ finally:
+ self.client.task_done()
+
+
+class Consumer(object):
+ def __init__(self, host, port, progressbar=True, request_timeout=2500):
+ self.daemon = True
+ self.started = False
+ self.address = 'tcp://%s:%s' % (host, port)
+ self.request_timeout = request_timeout
+ self.queue = Queue()
+ self._wants_job = False
+
+ if progressbar:
+ self.pbar = type(self).get_progressbar()
+ else:
+ self.pbar = None
+
+ @classmethod
+ def get_progressbar(cls):
+ from taskmaster.progressbar import Counter, Speed, Timer, ProgressBar, UnknownLength
+
+ widgets = ['Current Job: ', Counter(), ' | ', Speed(), ' | ', Timer()]
+
+ pbar = ProgressBar(widgets=widgets, maxval=UnknownLength)
+
+ return pbar
+
+ def start(self, target):
+ self.started = True
+ self.tasks_completed = 0
+
+ self.context = context = zmq.Context(1)
+ self.client = client = context.socket(zmq.REQ)
+ self.poll = poll = zmq.Poller()
+
+ client.connect(self.address)
+ poll.register(client, zmq.POLLIN)
+
+ worker = Worker(self, target)
+ gevent.spawn(worker.run)
+
+ print "Connecting to server on %r" % self.address
+
+ if self.pbar:
+ self.pbar.start()
+
+ while True:
+ # If the queue has items in it, we just loop
+ if not self._wants_job:
+ gevent.sleep(0)
+ continue
+
+ client.send('GET')
+ socks = dict(poll.poll(self.request_timeout))
+ if socks.get(client) != zmq.POLLIN:
+ # server connection closed
+ break
+
+ reply = client.recv()
+ if not reply:
+ break
+
+ # Reply can be "WAIT", "OK", or "ERROR"
+ if reply.startswith('OK '):
+ self._wants_job = False
+ job = pickle.loads(reply[3:])
+ self.queue.put(job)
+
+ self.shutdown()
+
+ def get_job(self):
+ self._wants_job = True
+
+ while True:
+ try:
+ return self.queue.get_nowait()
+ except Empty:
+ gevent.sleep(0)
+
+ def task_done(self):
+ self.tasks_completed += 1
+ if self.pbar:
+ self.pbar.update(self.tasks_completed)
+
+ def shutdown(self):
+ if not self.started:
+ return
+ self.poll.unregister(self.client)
+ self.client.close()
+ self.context.term()
+ if self.pbar:
+ self.pbar.finish()
+ self.started = False
@@ -6,11 +6,10 @@
:license: Apache License 2.0, see LICENSE for more details.
"""
-import sys
-import time
import cPickle as pickle
+import gevent
+import sys
from os import path, unlink
-from threading import Thread
from taskmaster.util import import_target
@@ -50,6 +49,8 @@ def read_state(self):
with open(self.state_file, 'r') as fp:
try:
return pickle.load(fp)
+ except EOFError:
+ pass
except Exception, e:
print "There was an error reading from state file. Ignoring and continuing without."
import traceback
@@ -72,19 +73,19 @@ def update_state(self, job_id, job, fp=None):
fp.seek(0)
pickle.dump(data, fp)
+ if self.pbar:
+ self.pbar.update(job_id)
+
def state_writer(self):
last_job_id = None
with open(self.state_file, 'w') as fp:
while self.server.is_alive():
- time.sleep(0)
+ gevent.sleep(0.01)
try:
job_id, job = self.server.first_job()
except IndexError:
continue
- if self.pbar:
- self.pbar.update(job_id)
-
if not job or job_id == last_job_id:
continue
@@ -105,21 +106,23 @@ def start(self):
else:
start_id = 0
- self.server.start()
+ gevent.spawn(self.server.start)
+
+ gevent.sleep(0)
+
if self.pbar:
self.pbar.start()
+ self.pbar.update(start_id)
- state_writer = Thread(target=self.state_writer)
- state_writer.daemon = True
- state_writer.start()
+ state_writer = gevent.spawn(self.state_writer)
job_id, job = (None, None)
for job_id, job in enumerate(self.target(**kwargs), start_id):
self.server.put_job((job_id, job))
- time.sleep(0)
+ gevent.sleep(0)
while self.server.has_work():
- time.sleep(0)
+ gevent.sleep(0)
self.server.shutdown()
state_writer.join(1)
@@ -10,7 +10,7 @@
def get_jobs(last=0):
# last_job would be sent if state was resumed
# from a previous run
- for i in xrange(last, 100000):
+ for i in xrange(last, 10000):
yield i
Oops, something went wrong.

0 comments on commit 50c563a

Please sign in to comment.