
improved client.get_results() behavior
minrk committed Apr 8, 2011
1 parent a49b646 commit 6f48516
Showing 8 changed files with 147 additions and 61 deletions.
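
The user-visible change is on the client side: get_results now returns deserialized results keyed by msg_id, alongside 'pending' and 'completed' lists, and populates a new client.metadata cache. A minimal sketch of the new call pattern (hypothetical controller address and task; assumes a running controller and engines):

from IPython.zmq.parallel.client import Client

rc = Client('tcp://127.0.0.1:10101')  # hypothetical controller address

def double(x):
    return 2 * x

# submit without blocking; apply signature as used elsewhere in this diff
ar = rc.apply(double, args=(21,), block=False)
msg_id = ar.msg_ids[0]

# status_only=True: just the 'pending' and 'completed' msg_id lists
status = rc.get_results(msg_id, status_only=True)
print status['pending'], status['completed']

# full form: completed msg_ids map to their deserialized results, and
# timing/engine info lands in rc.metadata (see _build_metadata below)
reply = rc.get_results(msg_id)
if msg_id in reply['completed']:
    print reply[msg_id]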
2 changes: 1 addition & 1 deletion IPython/zmq/parallel/asyncresult.py
@@ -40,7 +40,7 @@ def _reconstruct_result(self, res):
Override me in subclasses for turning a list of results
into the expected form.
"""
if len(res) == 1:
if len(self.msg_ids) == 1:
return res[0]
else:
return res
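
The fix above matters when a single message's result is itself a list: unwrapping now keys off len(self.msg_ids) rather than the shape of the result. For context, a sketch of how a subclass might use this hook (illustrative subclass, not part of this commit):

from IPython.zmq.parallel.asyncresult import AsyncResult

class JoinedResult(AsyncResult):
    """Illustrative: concatenate per-message list results."""
    def _reconstruct_result(self, res):
        if len(self.msg_ids) == 1:
            # one message: return its result unchanged,
            # even when that result is itself a list
            return res[0]
        joined = []
        for sub in res:
            joined.extend(sub)
        return joined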
104 changes: 82 additions & 22 deletions IPython/zmq/parallel/client.py
@@ -14,6 +14,7 @@
import time
from getpass import getpass
from pprint import pprint
from datetime import datetime

import zmq
from zmq.eventloop import ioloop, zmqstream
@@ -29,6 +30,7 @@
import map as Map
from asyncresult import AsyncResult, AsyncMapResult
from remotefunction import remote,parallel,ParallelFunction,RemoteFunction
from util import ReverseDict

#--------------------------------------------------------------------------
# helpers for implementing old MEC API via client.apply
@@ -83,6 +85,11 @@ def defaultblock(f, self, *args, **kwargs):
self.block = saveblock
return ret


#--------------------------------------------------------------------------
# Classes
#--------------------------------------------------------------------------

class AbortedTask(object):
"""A basic wrapper object describing an aborted task."""
def __init__(self, msg_id):
@@ -233,10 +240,11 @@ def __init__(self, addr='tcp://127.0.0.1:10101', context=None, username=None, de
tunnel.tunnel_connection(self._registration_socket, addr, sshserver, **ssh_kwargs)
else:
self._registration_socket.connect(addr)
self._engines = {}
self._engines = ReverseDict()
self._ids = set()
self.outstanding=set()
self.results = {}
self.metadata = {}
self.history = []
self.debug = debug
self.session.debug = debug
@@ -342,9 +350,27 @@ def _unregister_engine(self, msg):
if eid in self._ids:
self._ids.remove(eid)
self._engines.pop(eid)

#
def _build_metadata(self, header, parent, content):
md = {'msg_id' : parent['msg_id'],
'submitted' : datetime.strptime(parent['date'], ss.ISO8601),
'started' : datetime.strptime(header['started'], ss.ISO8601),
'completed' : datetime.strptime(header['date'], ss.ISO8601),
'received' : datetime.now(),
'engine_uuid' : header['engine'],
'engine_id' : self._engines.get(header['engine'], None),
'follow' : parent['follow'],
'after' : parent['after'],
'status' : content['status']
}
return md

def _handle_execute_reply(self, msg):
"""Save the reply to an execute_request into our results."""
"""Save the reply to an execute_request into our results.
execute messages are never actually used. apply is used instead.
"""

parent = msg['parent_header']
msg_id = parent['msg_id']
if msg_id not in self.outstanding:
@@ -362,8 +388,12 @@ def _handle_apply_reply(self, msg):
else:
self.outstanding.remove(msg_id)
content = msg['content']
header = msg['header']

self.metadata[msg_id] = self._build_metadata(header, parent, content)

if content['status'] == 'ok':
self.results[msg_id] = ss.unserialize_object(msg['buffers'])
self.results[msg_id] = ss.unserialize_object(msg['buffers'])[0]
elif content['status'] == 'aborted':
self.results[msg_id] = error.AbortedTask(msg_id)
elif content['status'] == 'resubmitted':
@@ -372,10 +402,8 @@
else:
e = ss.unwrap_exception(content)
e_uuid = e.engine_info['engineid']
for k,v in self._engines.iteritems():
if v == e_uuid:
e.engine_info['engineid'] = k
break
eid = self._engines[e_uuid]
e.engine_info['engineid'] = eid
self.results[msg_id] = e

def _flush_notifications(self):
@@ -882,6 +910,13 @@ def get_results(self, msg_ids, status_only=False):
status_only : bool (default: False)
if False:
return the actual results
Returns
-------
results : dict
There will always be the keys 'pending' and 'completed', which will
be lists of msg_ids.
"""
if not isinstance(msg_ids, (list,tuple)):
msg_ids = [msg_ids]
@@ -895,11 +930,12 @@

completed = []
local_results = {}
for msg_id in list(theids):
if msg_id in self.results:
completed.append(msg_id)
local_results[msg_id] = self.results[msg_id]
theids.remove(msg_id)
# temporarily disable local shortcut
# for msg_id in list(theids):
# if msg_id in self.results:
# completed.append(msg_id)
# local_results[msg_id] = self.results[msg_id]
# theids.remove(msg_id)

if theids: # some not locally cached
content = dict(msg_ids=theids, status_only=status_only)
@@ -911,16 +947,40 @@
content = msg['content']
if content['status'] != 'ok':
raise ss.unwrap_exception(content)
buffers = msg['buffers']
else:
content = dict(completed=[],pending=[])
if not status_only:
# load cached results into result:
content['completed'].extend(completed)
content.update(local_results)
# update cache with results:
for msg_id in msg_ids:
if msg_id in content['completed']:
self.results[msg_id] = content[msg_id]

content['completed'].extend(completed)

if status_only:
return content

failures = []
# load cached results into result:
content.update(local_results)
# update cache with results:
for msg_id in sorted(theids):
if msg_id in content['completed']:
rec = content[msg_id]
parent = rec['header']
header = rec['result_header']
rcontent = rec['result_content']
if isinstance(rcontent, str):
rcontent = self.session.unpack(rcontent)

self.metadata[msg_id] = self._build_metadata(header, parent, rcontent)

if rcontent['status'] == 'ok':
res,buffers = ss.unserialize_object(buffers)
else:
res = ss.unwrap_exception(rcontent)
failures.append(res)

self.results[msg_id] = res
content[msg_id] = res

error.collect_exceptions(failures, "get_results")
return content

@spinfirst
@@ -945,7 +1005,7 @@ def queue_status(self, targets=None, verbose=False):
status = content.pop('status')
if status != 'ok':
raise ss.unwrap_exception(content)
return content
return ss.rekey(content)

@spinfirst
def purge_results(self, msg_ids=[], targets=[]):
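
Each apply reply (and each record fetched by get_results) now feeds _build_metadata, so per-task details are available after the fact. A sketch of inspecting the cache, continuing the first sketch's rc and msg_id (field names taken from the diff above):

md = rc.metadata[msg_id]
# 'submitted', 'started', 'completed' are datetimes parsed from ISO8601 headers
print 'time in queue:', md['started'] - md['submitted']
print 'execution time:', md['completed'] - md['started']
print 'engine:', md['engine_id'], md['engine_uuid']
print 'status:', md['status']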
45 changes: 15 additions & 30 deletions IPython/zmq/parallel/controller.py
@@ -47,33 +47,6 @@
def _passer(*args, **kwargs):
return

class ReverseDict(dict):
"""simple double-keyed subset of dict methods."""

def __init__(self, *args, **kwargs):
dict.__init__(self, *args, **kwargs)
self.reverse = dict()
for key, value in self.iteritems():
self.reverse[value] = key

def __getitem__(self, key):
try:
return dict.__getitem__(self, key)
except KeyError:
return self.reverse[key]

def __setitem__(self, key, value):
if key in self.reverse:
raise KeyError("Can't have key %r on both sides!"%key)
dict.__setitem__(self, key, value)
self.reverse[value] = key

def pop(self, key):
value = dict.pop(self, key)
self.d1.pop(value)
return value


def init_record(msg):
"""return an empty TaskRecord dict, with all keys initialized with None."""
header = msg['header']
@@ -484,6 +457,8 @@ def save_queue_result(self, idents, msg):
}
if MongoDB is not None and isinstance(self.db, MongoDB):
result['result_buffers'] = map(Binary, msg['buffers'])
else:
result['result_buffers'] = msg['buffers']
self.db.update_record(msg_id, result)
else:
logger.debug("queue:: unknown msg finished %s"%msg_id)
@@ -552,6 +527,8 @@ def save_task_result(self, idents, msg):
}
if MongoDB is not None and isinstance(self.db, MongoDB):
result['result_buffers'] = map(Binary, msg['buffers'])
else:
result['result_buffers'] = msg['buffers']
self.db.update_record(msg_id, result)

else:
@@ -831,30 +808,38 @@ def resubmit_task(self, client_id, msg, buffers):
def get_results(self, client_id, msg):
"""Get the result of 1 or more messages."""
content = msg['content']
msg_ids = set(content['msg_ids'])
msg_ids = sorted(set(content['msg_ids']))
statusonly = content.get('status_only', False)
pending = []
completed = []
content = dict(status='ok')
content['pending'] = pending
content['completed'] = completed
buffers = []
if not statusonly:
content['results'] = {}
records = self.db.find_records(dict(msg_id={'$in':msg_ids}))
for msg_id in msg_ids:
if msg_id in self.pending:
pending.append(msg_id)
elif msg_id in self.all_completed:
completed.append(msg_id)
if not statusonly:
content[msg_id] = records[msg_id]['result_content']
rec = records[msg_id]
content[msg_id] = { 'result_content': rec['result_content'],
'header': rec['header'],
'result_header' : rec['result_header'],
}
buffers.extend(map(str, rec['result_buffers']))
else:
try:
raise KeyError('No such message: '+msg_id)
except:
content = wrap_exception()
break
self.session.send(self.clientele, "result_reply", content=content,
parent=msg, ident=client_id)
parent=msg, ident=client_id,
buffers=buffers)
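
With buffers now riding on the reply, a completed result_reply is roughly shaped like the sketch below (hypothetical msg_ids; header dicts elided):

content = {
    'status': 'ok',
    'pending': ['msg-b'],                    # still outstanding
    'completed': ['msg-a'],                  # records included below
    'msg-a': {
        'result_content': {'status': 'ok'},  # execution status dict
        'header': {},                        # original request header (elided)
        'result_header': {},                 # reply header with timing (elided)
    },
}
# serialized result objects travel separately as one flat buffer list,
# concatenated in sorted(msg_id) order; the client pops them back off
# one record at a time via unserialize_object (see streamsession.py below)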


#-------------------------------------------------------------------------
10 changes: 7 additions & 3 deletions IPython/zmq/parallel/mongodb.py
@@ -35,7 +35,7 @@ def get_record(self, msg_id):
def update_record(self, msg_id, rec):
"""Update the data in an existing record."""
obj_id = self._table[msg_id]
self._records.update({'_id':obj_id}, rec)
self._records.update({'_id':obj_id}, {'$set': rec})

def drop_matching_records(self, check):
"""Remove a record from the DB."""
@@ -50,7 +50,11 @@ def find_records(self, check, id_only=False):
"""Find records matching a query dict."""
matches = list(self._records.find(check))
if id_only:
matches = [ rec['msg_id'] for rec in matches ]
return matches
return [ rec['msg_id'] for rec in matches ]
else:
data = {}
for rec in matches:
data[rec['msg_id']] = rec
return data
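
The dict return is what the controller's get_results indexes by msg_id above. A usage sketch (assumes a configured MongoDB backend instance db; query syntax is standard pymongo, and the 'completed' field is illustrative):

# fetch full records for a batch of msg_ids, keyed by msg_id
records = db.find_records({'msg_id': {'$in': ['msg-a', 'msg-b']}})
for msg_id, rec in records.iteritems():
    print msg_id, rec['result_content']['status']

# id_only=True still returns a flat list of matching msg_ids
done = db.find_records({'completed': {'$ne': None}}, id_only=True)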


4 changes: 2 additions & 2 deletions IPython/zmq/parallel/remotefunction.py
@@ -126,10 +126,10 @@ def __call__(self, *sequences):
f=self.func
mid = self.client.apply(f, args=args, block=False,
bound=self.bound,
targets=engineid)._msg_ids[0]
targets=engineid).msg_ids[0]
msg_ids.append(mid)

r = AsyncMapResult(self.client, msg_ids, self.mapObject)
r = AsyncMapResult(self.client, msg_ids, self.mapObject, fname=self.func.__name__)
if self.block:
r.wait()
return r.result
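
Passing fname through means map failures can be reported against the mapped function's name. ParallelFunction can also be constructed directly, with the same signature view.py uses below (hypothetical client and engine ids):

from IPython.zmq.parallel.remotefunction import ParallelFunction

def double(x):
    return 2 * x

pf = ParallelFunction(rc, double, block=True, bound=True, targets=[0, 1])
print pf.map(range(10))   # errors, if any, now name 'double'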
6 changes: 3 additions & 3 deletions IPython/zmq/parallel/streamsession.py
@@ -208,19 +208,19 @@ def unserialize_object(bufs):
for s in sobj:
if s.data is None:
s.data = bufs.pop(0)
return uncanSequence(map(unserialize, sobj))
return uncanSequence(map(unserialize, sobj)), bufs
elif isinstance(sobj, dict):
newobj = {}
for k in sorted(sobj.iterkeys()):
s = sobj[k]
if s.data is None:
s.data = bufs.pop(0)
newobj[k] = uncan(unserialize(s))
return newobj
return newobj, bufs
else:
if sobj.data is None:
sobj.data = bufs.pop(0)
return uncan(unserialize(sobj))
return uncan(unserialize(sobj)), bufs

def pack_apply_message(f, args, kwargs, threshold=64e-6):
"""pack up a function, args, and kwargs to be sent over the wire
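
Callers must now unpack an (object, remaining_buffers) pair, which is what lets the client above walk one flat buffer list across many records. A round-trip sketch (it is an assumption here that serialize_object in this module returns the matching buffer list, as on the apply path):

from IPython.zmq.parallel import streamsession as ss

# assumption: serialize_object packs one object into a list of buffers
bufs = ss.serialize_object(['some', 'result'])

obj, remaining = ss.unserialize_object(bufs)
print obj          # -> ['some', 'result']
print remaining    # -> [] once every buffer has been consumed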
35 changes: 35 additions & 0 deletions IPython/zmq/parallel/util.py
@@ -0,0 +1,35 @@
"""some generic utilities"""

class ReverseDict(dict):
"""simple double-keyed subset of dict methods."""

def __init__(self, *args, **kwargs):
dict.__init__(self, *args, **kwargs)
self._reverse = dict()
for key, value in self.iteritems():
self._reverse[value] = key

def __getitem__(self, key):
try:
return dict.__getitem__(self, key)
except KeyError:
return self._reverse[key]

def __setitem__(self, key, value):
if key in self._reverse:
raise KeyError("Can't have key %r on both sides!"%key)
dict.__setitem__(self, key, value)
self._reverse[value] = key

def pop(self, key):
value = dict.pop(self, key)
self._reverse.pop(value)
return value

def get(self, key, default=None):
try:
return self[key]
except KeyError:
return default
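
ReverseDict is what lets the client resolve engines by id or uuid in either direction (see _handle_apply_reply above). A quick sketch:

from IPython.zmq.parallel.util import ReverseDict

engines = ReverseDict()
engines[0] = 'abc123'       # hypothetical engine uuid
print engines[0]            # -> 'abc123' (normal lookup)
print engines['abc123']     # -> 0        (reverse lookup)
print engines.get('nope')   # -> None     (get falls back to the default)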


2 changes: 2 additions & 0 deletions IPython/zmq/parallel/view.py
@@ -183,6 +183,8 @@ def map(self, f, *sequences):
"""Parallel version of builtin `map`, using this view's engines."""
if isinstance(self.targets, int):
targets = [self.targets]
else:
targets = self.targets
pf = ParallelFunction(self.client, f, block=self.block,
bound=True, targets=targets)
return pf.map(*sequences)
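
Before this fix, calling map on a view whose targets was a list raised NameError, since targets was only bound in the int branch. A sketch of both paths (assumes Client indexing returns views, as in this API; rc as in the first sketch):

def square(x):
    return x ** 2

v_one = rc[0]        # single-engine view: targets is an int
v_many = rc[[0, 1]]  # multi-engine view: targets is a list (previously broke)

print v_one.map(square, range(4))
print v_many.map(square, range(4))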
