Permalink
Browse files

improved client.get_results() behavior

  • Loading branch information...
minrk committed Jan 29, 2011
1 parent a49b646 commit 6f48516f6a0054fb1f58dfa4bf424af7d3643e8b
@@ -40,7 +40,7 @@ def _reconstruct_result(self, res):
Override me in subclasses for turning a list of results
into the expected form.
"""
- if len(res) == 1:
+ if len(self.msg_ids) == 1:
return res[0]
else:
return res
@@ -14,6 +14,7 @@
import time
from getpass import getpass
from pprint import pprint
+from datetime import datetime
import zmq
from zmq.eventloop import ioloop, zmqstream
@@ -29,6 +30,7 @@
import map as Map
from asyncresult import AsyncResult, AsyncMapResult
from remotefunction import remote,parallel,ParallelFunction,RemoteFunction
+from util import ReverseDict
#--------------------------------------------------------------------------
# helpers for implementing old MEC API via client.apply
@@ -83,6 +85,11 @@ def defaultblock(f, self, *args, **kwargs):
self.block = saveblock
return ret
+
+#--------------------------------------------------------------------------
+# Classes
+#--------------------------------------------------------------------------
+
class AbortedTask(object):
"""A basic wrapper object describing an aborted task."""
def __init__(self, msg_id):
@@ -233,10 +240,11 @@ def __init__(self, addr='tcp://127.0.0.1:10101', context=None, username=None, de
tunnel.tunnel_connection(self._registration_socket, addr, sshserver, **ssh_kwargs)
else:
self._registration_socket.connect(addr)
- self._engines = {}
+ self._engines = ReverseDict()
self._ids = set()
self.outstanding=set()
self.results = {}
+ self.metadata = {}
self.history = []
self.debug = debug
self.session.debug = debug
@@ -342,9 +350,27 @@ def _unregister_engine(self, msg):
if eid in self._ids:
self._ids.remove(eid)
self._engines.pop(eid)
-
+ #
    def _build_metadata(self, header, parent, content):
        """Assemble the metadata dict recorded for one completed message.

        Pulls timing information out of the request (parent) and reply
        headers, identifies the engine that ran the task, and copies the
        dependency and status fields.

        Parameters
        ----------
        header : dict
            The reply message's header; read for 'started', 'date', 'engine'.
        parent : dict
            The original request's header; read for 'msg_id', 'date',
            'follow', 'after'.
        content : dict
            The reply message's content; read for 'status'.

        Returns
        -------
        dict
            Timing, engine, dependency, and status entries, suitable for
            storing in ``self.metadata[msg_id]``.
        """
        md = {'msg_id' : parent['msg_id'],
              # timestamps arrive as ISO8601 strings; parse into datetimes
              'submitted' : datetime.strptime(parent['date'], ss.ISO8601),
              'started' : datetime.strptime(header['started'], ss.ISO8601),
              'completed' : datetime.strptime(header['date'], ss.ISO8601),
              # local, naive receive time -- NOTE(review): not comparable
              # across machines with unsynchronized clocks
              'received' : datetime.now(),
              'engine_uuid' : header['engine'],
              # reverse lookup uuid -> integer id; None if the engine is
              # no longer registered
              'engine_id' : self._engines.get(header['engine'], None),
              'follow' : parent['follow'],
              'after' : parent['after'],
              'status' : content['status']
              }
        return md
+
def _handle_execute_reply(self, msg):
- """Save the reply to an execute_request into our results."""
+ """Save the reply to an execute_request into our results.
+
+ execute messages are never actually used. apply is used instead.
+ """
+
parent = msg['parent_header']
msg_id = parent['msg_id']
if msg_id not in self.outstanding:
@@ -362,8 +388,12 @@ def _handle_apply_reply(self, msg):
else:
self.outstanding.remove(msg_id)
content = msg['content']
+ header = msg['header']
+
+ self.metadata[msg_id] = self._build_metadata(header, parent, content)
+
if content['status'] == 'ok':
- self.results[msg_id] = ss.unserialize_object(msg['buffers'])
+ self.results[msg_id] = ss.unserialize_object(msg['buffers'])[0]
elif content['status'] == 'aborted':
self.results[msg_id] = error.AbortedTask(msg_id)
elif content['status'] == 'resubmitted':
@@ -372,10 +402,8 @@ def _handle_apply_reply(self, msg):
else:
e = ss.unwrap_exception(content)
e_uuid = e.engine_info['engineid']
- for k,v in self._engines.iteritems():
- if v == e_uuid:
- e.engine_info['engineid'] = k
- break
+ eid = self._engines[e_uuid]
+ e.engine_info['engineid'] = eid
self.results[msg_id] = e
def _flush_notifications(self):
@@ -882,6 +910,13 @@ def get_results(self, msg_ids, status_only=False):
status_only : bool (default: False)
if False:
return the actual results
+
+ Returns
+ -------
+
+ results : dict
+ There will always be the keys 'pending' and 'completed', which will
+ be lists of msg_ids.
"""
if not isinstance(msg_ids, (list,tuple)):
msg_ids = [msg_ids]
@@ -895,11 +930,12 @@ def get_results(self, msg_ids, status_only=False):
completed = []
local_results = {}
- for msg_id in list(theids):
- if msg_id in self.results:
- completed.append(msg_id)
- local_results[msg_id] = self.results[msg_id]
- theids.remove(msg_id)
+ # temporarily disable local shortcut
+ # for msg_id in list(theids):
+ # if msg_id in self.results:
+ # completed.append(msg_id)
+ # local_results[msg_id] = self.results[msg_id]
+ # theids.remove(msg_id)
if theids: # some not locally cached
content = dict(msg_ids=theids, status_only=status_only)
@@ -911,16 +947,40 @@ def get_results(self, msg_ids, status_only=False):
content = msg['content']
if content['status'] != 'ok':
raise ss.unwrap_exception(content)
+ buffers = msg['buffers']
else:
content = dict(completed=[],pending=[])
- if not status_only:
- # load cached results into result:
- content['completed'].extend(completed)
- content.update(local_results)
- # update cache with results:
- for msg_id in msg_ids:
- if msg_id in content['completed']:
- self.results[msg_id] = content[msg_id]
+
+ content['completed'].extend(completed)
+
+ if status_only:
+ return content
+
+ failures = []
+ # load cached results into result:
+ content.update(local_results)
+ # update cache with results:
+ for msg_id in sorted(theids):
+ if msg_id in content['completed']:
+ rec = content[msg_id]
+ parent = rec['header']
+ header = rec['result_header']
+ rcontent = rec['result_content']
+ if isinstance(rcontent, str):
+ rcontent = self.session.unpack(rcontent)
+
+ self.metadata[msg_id] = self._build_metadata(header, parent, rcontent)
+
+ if rcontent['status'] == 'ok':
+ res,buffers = ss.unserialize_object(buffers)
+ else:
+ res = ss.unwrap_exception(rcontent)
+ failures.append(res)
+
+ self.results[msg_id] = res
+ content[msg_id] = res
+
+ error.collect_exceptions(failures, "get_results")
return content
@spinfirst
@@ -945,7 +1005,7 @@ def queue_status(self, targets=None, verbose=False):
status = content.pop('status')
if status != 'ok':
raise ss.unwrap_exception(content)
- return content
+ return ss.rekey(content)
@spinfirst
def purge_results(self, msg_ids=[], targets=[]):
@@ -47,33 +47,6 @@
def _passer(*args, **kwargs):
return
-class ReverseDict(dict):
- """simple double-keyed subset of dict methods."""
-
- def __init__(self, *args, **kwargs):
- dict.__init__(self, *args, **kwargs)
- self.reverse = dict()
- for key, value in self.iteritems():
- self.reverse[value] = key
-
- def __getitem__(self, key):
- try:
- return dict.__getitem__(self, key)
- except KeyError:
- return self.reverse[key]
-
- def __setitem__(self, key, value):
- if key in self.reverse:
- raise KeyError("Can't have key %r on both sides!"%key)
- dict.__setitem__(self, key, value)
- self.reverse[value] = key
-
- def pop(self, key):
- value = dict.pop(self, key)
- self.d1.pop(value)
- return value
-
-
def init_record(msg):
"""return an empty TaskRecord dict, with all keys initialized with None."""
header = msg['header']
@@ -484,6 +457,8 @@ def save_queue_result(self, idents, msg):
}
if MongoDB is not None and isinstance(self.db, MongoDB):
result['result_buffers'] = map(Binary, msg['buffers'])
+ else:
+ result['result_buffers'] = msg['buffers']
self.db.update_record(msg_id, result)
else:
logger.debug("queue:: unknown msg finished %s"%msg_id)
@@ -552,6 +527,8 @@ def save_task_result(self, idents, msg):
}
if MongoDB is not None and isinstance(self.db, MongoDB):
result['result_buffers'] = map(Binary, msg['buffers'])
+ else:
+ result['result_buffers'] = msg['buffers']
self.db.update_record(msg_id, result)
else:
@@ -831,30 +808,38 @@ def resubmit_task(self, client_id, msg, buffers):
def get_results(self, client_id, msg):
"""Get the result of 1 or more messages."""
content = msg['content']
- msg_ids = set(content['msg_ids'])
+ msg_ids = sorted(set(content['msg_ids']))
statusonly = content.get('status_only', False)
pending = []
completed = []
content = dict(status='ok')
content['pending'] = pending
content['completed'] = completed
+ buffers = []
if not statusonly:
+ content['results'] = {}
records = self.db.find_records(dict(msg_id={'$in':msg_ids}))
for msg_id in msg_ids:
if msg_id in self.pending:
pending.append(msg_id)
elif msg_id in self.all_completed:
completed.append(msg_id)
if not statusonly:
- content[msg_id] = records[msg_id]['result_content']
+ rec = records[msg_id]
+ content[msg_id] = { 'result_content': rec['result_content'],
+ 'header': rec['header'],
+ 'result_header' : rec['result_header'],
+ }
+ buffers.extend(map(str, rec['result_buffers']))
else:
try:
raise KeyError('No such message: '+msg_id)
except:
content = wrap_exception()
break
self.session.send(self.clientele, "result_reply", content=content,
- parent=msg, ident=client_id)
+ parent=msg, ident=client_id,
+ buffers=buffers)
#-------------------------------------------------------------------------
@@ -35,7 +35,7 @@ def get_record(self, msg_id):
def update_record(self, msg_id, rec):
"""Update the data in an existing record."""
obj_id = self._table[msg_id]
- self._records.update({'_id':obj_id}, rec)
+ self._records.update({'_id':obj_id}, {'$set': rec})
def drop_matching_records(self, check):
"""Remove a record from the DB."""
@@ -50,7 +50,11 @@ def find_records(self, check, id_only=False):
"""Find records matching a query dict."""
matches = list(self._records.find(check))
if id_only:
- matches = [ rec['msg_id'] for rec in matches ]
- return matches
+ return [ rec['msg_id'] for rec in matches ]
+ else:
+ data = {}
+ for rec in matches:
+ data[rec['msg_id']] = rec
+ return data
@@ -126,10 +126,10 @@ def __call__(self, *sequences):
f=self.func
mid = self.client.apply(f, args=args, block=False,
bound=self.bound,
- targets=engineid)._msg_ids[0]
+ targets=engineid).msg_ids[0]
msg_ids.append(mid)
- r = AsyncMapResult(self.client, msg_ids, self.mapObject)
+ r = AsyncMapResult(self.client, msg_ids, self.mapObject, fname=self.func.__name__)
if self.block:
r.wait()
return r.result
@@ -208,19 +208,19 @@ def unserialize_object(bufs):
for s in sobj:
if s.data is None:
s.data = bufs.pop(0)
- return uncanSequence(map(unserialize, sobj))
+ return uncanSequence(map(unserialize, sobj)), bufs
elif isinstance(sobj, dict):
newobj = {}
for k in sorted(sobj.iterkeys()):
s = sobj[k]
if s.data is None:
s.data = bufs.pop(0)
newobj[k] = uncan(unserialize(s))
- return newobj
+ return newobj, bufs
else:
if sobj.data is None:
sobj.data = bufs.pop(0)
- return uncan(unserialize(sobj))
+ return uncan(unserialize(sobj)), bufs
def pack_apply_message(f, args, kwargs, threshold=64e-6):
"""pack up a function, args, and kwargs to be sent over the wire
@@ -0,0 +1,35 @@
+"""some generic utilities"""
+
class ReverseDict(dict):
    """simple double-keyed subset of dict methods.

    Behaves like a dict for forward (key -> value) access, but also
    maintains a reverse index (value -> key) so ``d[value]`` works too.
    Keys and values must all be hashable, and the two namespaces must be
    disjoint: inserting a key that already exists as a value raises.
    """

    def __init__(self, *args, **kwargs):
        dict.__init__(self, *args, **kwargs)
        self._reverse = dict()
        # mirror any initial contents into the reverse index
        # (items() instead of iteritems() keeps this portable)
        for key, value in self.items():
            self._reverse[value] = key

    def __getitem__(self, key):
        """Look up `key` forward first, then fall back to the reverse index."""
        try:
            return dict.__getitem__(self, key)
        except KeyError:
            return self._reverse[key]

    def __setitem__(self, key, value):
        """Insert a pair, refusing keys that collide with existing values."""
        if key in self._reverse:
            raise KeyError("Can't have key %r on both sides!"%key)
        dict.__setitem__(self, key, value)
        self._reverse[value] = key

    def pop(self, key):
        """Remove `key` and return its value, dropping the reverse entry too.

        Bug fix: this previously did ``self.d1.pop(value)`` -- there is no
        ``d1`` attribute, so every pop raised AttributeError and the stale
        reverse mapping was never removed.
        """
        value = dict.pop(self, key)
        self._reverse.pop(value)
        return value

    def get(self, key, default=None):
        """dict.get semantics, honoring both forward and reverse lookup."""
        try:
            return self[key]
        except KeyError:
            return default
+
@@ -183,6 +183,8 @@ def map(self, f, *sequences):
"""Parallel version of builtin `map`, using this view's engines."""
if isinstance(self.targets, int):
targets = [self.targets]
+ else:
+ targets = self.targets
pf = ParallelFunction(self.client, f, block=self.block,
bound=True, targets=targets)
return pf.map(*sequences)

0 comments on commit 6f48516

Please sign in to comment.