
add Client.resubmit for re-running tasks

closes gh-411

* allow `content` in session.serialize to be a unicode object, because mongo+JSON cannot be relied upon to produce encoded bytes.
commit 0c043a69bbbbe986b0ed6c602afbec7d8e386782 (1 parent: 6549d09), committed by @minrk on May 4, 2011
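
A minimal usage sketch of the new API (assuming a running IPython.parallel cluster and the usual `from IPython.parallel import Client` entry point; the function and the one-second timeouts are illustrative, mirroring the new `test_resubmit` test below):

    from IPython.parallel import Client

    def f():
        import random
        return random.random()

    rc = Client()
    view = rc.load_balanced_view()
    ar = view.apply_async(f)
    r1 = ar.get(1)                   # wait up to 1s for the first run

    # re-run the completed task by msg_id; resubmit returns an AsyncHubResult,
    # and resubmitting a task that is still in flight raises ValueError remotely
    ahr = rc.resubmit(ar.msg_ids)
    r2 = ahr.get(1)                  # a fresh result, independent of r1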
IPython/parallel/client/client.py (62 lines changed)
@@ -1041,6 +1041,68 @@ def get_result(self, indices_or_msg_ids=None, block=None):
ar.wait()
return ar
+
+ @spin_first
+ def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
+ """Resubmit one or more tasks.
+
+ In-flight tasks may not be resubmitted.
+
+ Parameters
+ ----------
+
+ indices_or_msg_ids : integer history index, str msg_id, or list of either
+ The indices or msg_ids of the tasks to be resubmitted
+
+ block : bool
+ Whether to wait for the result to be done
+
+ Returns
+ -------
+
+ AsyncHubResult
+ A subclass of AsyncResult that retrieves results from the Hub
+
+ """
+ block = self.block if block is None else block
+ if indices_or_msg_ids is None:
+ indices_or_msg_ids = -1
+
+ if not isinstance(indices_or_msg_ids, (list,tuple)):
+ indices_or_msg_ids = [indices_or_msg_ids]
+
+ theids = []
+ for id in indices_or_msg_ids:
+ if isinstance(id, int):
+ id = self.history[id]
+ if not isinstance(id, str):
+ raise TypeError("indices must be str or int, not %r"%id)
+ theids.append(id)
+
+ for msg_id in theids:
+ self.outstanding.discard(msg_id)
+ if msg_id in self.history:
+ self.history.remove(msg_id)
+ self.results.pop(msg_id, None)
+ self.metadata.pop(msg_id, None)
+ content = dict(msg_ids = theids)
+
+ self.session.send(self._query_socket, 'resubmit_request', content)
+
+ zmq.select([self._query_socket], [], [])
+ idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
+ if self.debug:
+ pprint(msg)
+ content = msg['content']
+ if content['status'] != 'ok':
+ raise self._unwrap_exception(content)
+
+ ar = AsyncHubResult(self, msg_ids=theids)
+
+ if block:
+ ar.wait()
+
+ return ar
@spin_first
def result_status(self, msg_ids, status_only=True):
IPython/parallel/controller/hub.py (127 lines changed)
@@ -268,8 +268,15 @@ def construct_hub(self):
}
self.log.debug("Hub engine addrs: %s"%self.engine_info)
self.log.debug("Hub client addrs: %s"%self.client_info)
+
+ # resubmit stream
+ r = ZMQStream(ctx.socket(zmq.XREQ), loop)
+ url = util.disambiguate_url(self.client_info['task'][-1])
+ r.setsockopt(zmq.IDENTITY, self.session.session)
+ r.connect(url)
+
self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
- query=q, notifier=n, db=self.db,
+ query=q, notifier=n, resubmit=r, db=self.db,
engine_info=self.engine_info, client_info=self.client_info,
logname=self.log.name)
@@ -315,8 +322,9 @@ class Hub(LoggingFactory):
loop=Instance(ioloop.IOLoop)
query=Instance(ZMQStream)
monitor=Instance(ZMQStream)
- heartmonitor=Instance(HeartMonitor)
notifier=Instance(ZMQStream)
+ resubmit=Instance(ZMQStream)
+ heartmonitor=Instance(HeartMonitor)
db=Instance(object)
client_info=Dict()
engine_info=Dict()
@@ -379,6 +387,9 @@ def __init__(self, **kwargs):
'connection_request': self.connection_request,
}
+ # ignore resubmit replies
+ self.resubmit.on_recv(lambda msg: None, copy=False)
+
self.log.info("hub::created hub")
@property
@@ -452,48 +463,49 @@ def _validate_targets(self, targets):
def dispatch_monitor_traffic(self, msg):
"""all ME and Task queue messages come through here, as well as
IOPub traffic."""
- self.log.debug("monitor traffic: %s"%msg[:2])
+ self.log.debug("monitor traffic: %r"%msg[:2])
switch = msg[0]
idents, msg = self.session.feed_identities(msg[1:])
if not idents:
- self.log.error("Bad Monitor Message: %s"%msg)
+ self.log.error("Bad Monitor Message: %r"%msg)
return
handler = self.monitor_handlers.get(switch, None)
if handler is not None:
handler(idents, msg)
else:
- self.log.error("Invalid monitor topic: %s"%switch)
+ self.log.error("Invalid monitor topic: %r"%switch)
def dispatch_query(self, msg):
"""Route registration requests and queries from clients."""
idents, msg = self.session.feed_identities(msg)
if not idents:
- self.log.error("Bad Query Message: %s"%msg)
+ self.log.error("Bad Query Message: %r"%msg)
return
client_id = idents[0]
try:
msg = self.session.unpack_message(msg, content=True)
except:
content = error.wrap_exception()
- self.log.error("Bad Query Message: %s"%msg, exc_info=True)
+ self.log.error("Bad Query Message: %r"%msg, exc_info=True)
self.session.send(self.query, "hub_error", ident=client_id,
content=content)
return
# print client_id, header, parent, content
#switch on message type:
msg_type = msg['msg_type']
- self.log.info("client::client %s requested %s"%(client_id, msg_type))
+ self.log.info("client::client %r requested %r"%(client_id, msg_type))
handler = self.query_handlers.get(msg_type, None)
try:
- assert handler is not None, "Bad Message Type: %s"%msg_type
+ assert handler is not None, "Bad Message Type: %r"%msg_type
except:
content = error.wrap_exception()
- self.log.error("Bad Message Type: %s"%msg_type, exc_info=True)
+ self.log.error("Bad Message Type: %r"%msg_type, exc_info=True)
self.session.send(self.query, "hub_error", ident=client_id,
content=content)
return
+
else:
handler(idents, msg)
@@ -560,9 +572,9 @@ def save_queue_request(self, idents, msg):
# it's posible iopub arrived first:
existing = self.db.get_record(msg_id)
for key,evalue in existing.iteritems():
- rvalue = record[key]
+ rvalue = record.get(key, None)
if evalue and rvalue and evalue != rvalue:
- self.log.error("conflicting initial state for record: %s:%s <> %s"%(msg_id, rvalue, evalue))
+ self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
elif evalue and not rvalue:
record[key] = evalue
self.db.update_record(msg_id, record)
@@ -648,10 +660,22 @@ def save_task_request(self, idents, msg):
try:
# it's posible iopub arrived first:
existing = self.db.get_record(msg_id)
+ if existing['resubmitted']:
+ for key in ('submitted', 'client_uuid', 'buffers'):
+ # don't clobber these keys on resubmit
+ # submitted and client_uuid should be different
+ # and buffers might be big, and shouldn't have changed
+ record.pop(key)
+ # still check content,header which should not change
+ # but are not as expensive to compare as buffers
+
for key,evalue in existing.iteritems():
- rvalue = record[key]
+ if key.endswith('buffers'):
+ # don't compare buffers
+ continue
+ rvalue = record.get(key, None)
if evalue and rvalue and evalue != rvalue:
- self.log.error("conflicting initial state for record: %s:%s <> %s"%(msg_id, rvalue, evalue))
+ self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
elif evalue and not rvalue:
record[key] = evalue
self.db.update_record(msg_id, record)
@@ -1075,9 +1099,68 @@ def purge_results(self, client_id, msg):
self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
- def resubmit_task(self, client_id, msg, buffers):
- """Resubmit a task."""
- raise NotImplementedError
+ def resubmit_task(self, client_id, msg):
+ """Resubmit one or more tasks."""
+ def finish(reply):
+ self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)
+
+ content = msg['content']
+ msg_ids = content['msg_ids']
+ reply = dict(status='ok')
+ try:
+ records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
+ 'header', 'content', 'buffers'])
+ except Exception:
+ self.log.error('db::db error finding tasks to resubmit', exc_info=True)
+ return finish(error.wrap_exception())
+
+ # validate msg_ids
+ found_ids = [ rec['msg_id'] for rec in records ]
+ invalid_ids = filter(lambda m: m in self.pending, found_ids)
+ if len(records) > len(msg_ids):
+ try:
+ raise RuntimeError("DB appears to be in an inconsistent state. "
+ "More matching records were found than should exist")
+ except Exception:
+ return finish(error.wrap_exception())
+ elif len(records) < len(msg_ids):
+ missing = [ m for m in msg_ids if m not in found_ids ]
+ try:
+ raise KeyError("No such msg(s): %s"%missing)
+ except KeyError:
+ return finish(error.wrap_exception())
+ elif invalid_ids:
+ msg_id = invalid_ids[0]
+ try:
+ raise ValueError("Task %r appears to be inflight"%(msg_id))
+ except Exception:
+ return finish(error.wrap_exception())
+
+ # clear the existing records
+ rec = empty_record()
+ map(rec.pop, ['msg_id', 'header', 'content', 'buffers', 'submitted'])
+ rec['resubmitted'] = datetime.now()
+ rec['queue'] = 'task'
+ rec['client_uuid'] = client_id[0]
+ try:
+ for msg_id in msg_ids:
+ self.all_completed.discard(msg_id)
+ self.db.update_record(msg_id, rec)
+ except Exception:
+ self.log.error('db::db error updating record', exc_info=True)
+ reply = error.wrap_exception()
+ else:
+ # send the messages
+ for rec in records:
+ header = rec['header']
+ msg = self.session.msg(header['msg_type'])
+ msg['content'] = rec['content']
+ msg['header'] = header
+ msg['msg_id'] = rec['msg_id']
+ self.session.send(self.resubmit, msg, buffers=rec['buffers'])
+
+ finish(reply)
+
def _extract_record(self, rec):
"""decompose a TaskRecord dict into subsection of reply for get_result"""
@@ -1124,12 +1207,20 @@ def get_results(self, client_id, msg):
for msg_id in msg_ids:
if msg_id in self.pending:
pending.append(msg_id)
- elif msg_id in self.all_completed or msg_id in records:
+ elif msg_id in self.all_completed:
completed.append(msg_id)
if not statusonly:
c,bufs = self._extract_record(records[msg_id])
content[msg_id] = c
buffers.extend(bufs)
+ elif msg_id in records:
+ if records[msg_id]['completed']:
+ completed.append(msg_id)
+ c,bufs = self._extract_record(records[msg_id])
+ content[msg_id] = c
+ buffers.extend(bufs)
+ else:
+ pending.append(msg_id)
else:
try:
raise KeyError('No such message: '+msg_id)
IPython/parallel/streamsession.py (3 lines changed)
@@ -186,6 +186,9 @@ def serialize(self, msg, ident=None):
elif isinstance(content, bytes):
# content is already packed, as in a relayed message
pass
+ elif isinstance(content, unicode):
+ # should be bytes, but JSON often spits out unicode
+ content = content.encode('utf8')
else:
raise TypeError("Content incorrect type: %s"%type(content))
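
For context on the serialize change above: under Python 2, JSON packing can hand back a `unicode` object instead of encoded bytes, for example when non-ASCII content is dumped with `ensure_ascii=False` or after a round trip through MongoDB, so `serialize` now encodes it to UTF-8 rather than rejecting it. An illustrative sketch of the condition being guarded against (Python 2 semantics assumed):

    import json

    # with non-ASCII data and ensure_ascii=False, json.dumps returns unicode
    packed = json.dumps({u'note': u'r\xe9sultat'}, ensure_ascii=False)
    isinstance(packed, unicode)      # True: not bytes yet
    wire = packed.encode('utf8')     # what serialize now does for unicode content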
IPython/parallel/tests/test_client.py (23 lines changed)
@@ -212,3 +212,26 @@ def test_hub_history(self):
time.sleep(0.25)
self.assertEquals(self.client.hub_history()[-1:],ar.msg_ids)
+ def test_resubmit(self):
+ def f():
+ import random
+ return random.random()
+ v = self.client.load_balanced_view()
+ ar = v.apply_async(f)
+ r1 = ar.get(1)
+ ahr = self.client.resubmit(ar.msg_ids)
+ r2 = ahr.get(1)
+ self.assertFalse(r1 == r2)
+
+ def test_resubmit_inflight(self):
+ """ensure ValueError on resubmit of inflight task"""
+ v = self.client.load_balanced_view()
+ ar = v.apply_async(time.sleep,1)
+ # give the message a chance to arrive
+ time.sleep(0.2)
+ self.assertRaisesRemote(ValueError, self.client.resubmit, ar.msg_ids)
+ ar.get(2)
+
+ def test_resubmit_badkey(self):
+ """ensure KeyError on resubmit of nonexistent task"""
+ self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
IPython/parallel/tests/test_lbview.py (2 lines changed)
@@ -36,7 +36,7 @@ def test_z_crash_task(self):
"""test graceful handling of engine death (balanced)"""
# self.add_engines(1)
ar = self.view.apply_async(crash)
- self.assertRaisesRemote(error.EngineError, ar.get)
+ self.assertRaisesRemote(error.EngineError, ar.get, 10)
eid = ar.engine_id
tic = time.time()
while eid in self.client.ids and time.time()-tic < 5:

