Permalink
Browse files

Improvements to dependency handling

Specifically:
  * add 'success_only' switch to Dependencies
  * Scheduler handles some cases where Dependencies are impossible to meet.
  • Loading branch information...
1 parent 9f1a03a commit d51586b9eda97ab57c65fa0be2aa7756f173ec15 @minrk minrk committed Feb 8, 2011
@@ -16,17 +16,23 @@
#-------------------------------------------------------------------------------
from types import FunctionType
+import copy
-# contents of codeutil should either be in here, or codeutil belongs in IPython/util
from IPython.zmq.parallel.dependency import dependent
+
import codeutil
+#-------------------------------------------------------------------------------
+# Classes
+#-------------------------------------------------------------------------------
+
+
class CannedObject(object):
def __init__(self, obj, keys=[]):
self.keys = keys
- self.obj = obj
+ self.obj = copy.copy(obj)
for key in keys:
- setattr(obj, key, can(getattr(obj, key)))
+ setattr(self.obj, key, can(getattr(obj, key)))
def getObject(self, g=None):
@@ -43,6 +49,7 @@ class CannedFunction(CannedObject):
def __init__(self, f):
self._checkType(f)
self.code = f.func_code
+ self.__name__ = f.__name__
def _checkType(self, obj):
assert isinstance(obj, FunctionType), "Not a function type"
@@ -53,6 +60,11 @@ def getFunction(self, g=None):
newFunc = FunctionType(self.code, g)
return newFunc
+#-------------------------------------------------------------------------------
+# Functions
+#-------------------------------------------------------------------------------
+
+
def can(obj):
if isinstance(obj, FunctionType):
return CannedFunction(obj)
@@ -36,6 +36,7 @@ def __init__(self, client, msg_ids, fname=''):
self._fname=fname
self._ready = False
self._success = None
+ self._flatten_result = len(msg_ids) == 1
def __repr__(self):
if self._ready:
@@ -49,7 +50,7 @@ def _reconstruct_result(self, res):
Override me in subclasses for turning a list of results
into the expected form.
"""
- if len(self.msg_ids) == 1:
+ if self._flatten_result:
return res[0]
else:
return res
@@ -115,7 +116,7 @@ def successful(self):
def get_dict(self, timeout=-1):
"""Get the results as a dict, keyed by engine_id."""
results = self.get(timeout)
- engine_ids = [md['engine_id'] for md in self._metadata ]
+ engine_ids = [ md['engine_id'] for md in self._metadata ]
bycount = sorted(engine_ids, key=lambda k: engine_ids.count(k))
maxcount = bycount.count(bycount[-1])
if maxcount > 1:
@@ -130,11 +131,17 @@ def result(self):
"""result property."""
return self._result
+ # abbreviated alias:
+ r = result
+
@property
@check_ready
def metadata(self):
"""metadata property."""
- return self._metadata
+ if self._flatten_result:
+ return self._metadata[0]
+ else:
+ return self._metadata
@property
def result_dict(self):
@@ -157,7 +164,11 @@ def __getitem__(self, key):
elif isinstance(key, slice):
return error.collect_exceptions(self._result[key], self._fname)
elif isinstance(key, basestring):
- return [ md[key] for md in self._metadata ]
+ values = [ md[key] for md in self._metadata ]
+ if self._flatten_result:
+ return values[0]
+ else:
+ return values
else:
raise TypeError("Invalid key type %r, must be 'int','slice', or 'str'"%type(key))
@@ -177,8 +188,9 @@ class AsyncMapResult(AsyncResult):
"""
def __init__(self, client, msg_ids, mapObject, fname=''):
- self._mapObject = mapObject
AsyncResult.__init__(self, client, msg_ids, fname=fname)
+ self._mapObject = mapObject
+ self._flatten_result = False
def _reconstruct_result(self, res):
"""Perform the gather on the actual results."""
@@ -91,7 +91,13 @@ def defaultblock(f, self, *args, **kwargs):
#--------------------------------------------------------------------------
class Metadata(dict):
- """Subclass of dict for initializing metadata values."""
+ """Subclass of dict for initializing metadata values.
+
+ Attribute access works on keys.
+
+ These objects have a strict set of keys - errors will raise if you try
+ to add new keys.
+ """
def __init__(self, *args, **kwargs):
dict.__init__(self)
md = {'msg_id' : None,
@@ -113,7 +119,27 @@ def __init__(self, *args, **kwargs):
}
self.update(md)
self.update(dict(*args, **kwargs))
+
+ def __getattr__(self, key):
+ """getattr aliased to getitem"""
+ if key in self.iterkeys():
+ return self[key]
+ else:
+ raise AttributeError(key)
+ def __setattr__(self, key, value):
+ """setattr aliased to setitem, with strict"""
+ if key in self.iterkeys():
+ self[key] = value
+ else:
+ raise AttributeError(key)
+
+ def __setitem__(self, key, value):
+ """strict static key enforcement"""
+ if key in self.iterkeys():
+ dict.__setitem__(self, key, value)
+ else:
+ raise KeyError(key)
class Client(object):
@@ -372,16 +398,22 @@ def _unregister_engine(self, msg):
def _extract_metadata(self, header, parent, content):
md = {'msg_id' : parent['msg_id'],
- 'submitted' : datetime.strptime(parent['date'], ss.ISO8601),
- 'started' : datetime.strptime(header['started'], ss.ISO8601),
- 'completed' : datetime.strptime(header['date'], ss.ISO8601),
'received' : datetime.now(),
- 'engine_uuid' : header['engine'],
- 'engine_id' : self._engines.get(header['engine'], None),
+ 'engine_uuid' : header.get('engine', None),
'follow' : parent['follow'],
'after' : parent['after'],
'status' : content['status'],
}
+
+ if md['engine_uuid'] is not None:
+ md['engine_id'] = self._engines.get(md['engine_uuid'], None)
+
+ if 'date' in parent:
+ md['submitted'] = datetime.strptime(parent['date'], ss.ISO8601)
+ if 'started' in header:
+ md['started'] = datetime.strptime(header['started'], ss.ISO8601)
+ if 'date' in header:
+ md['completed'] = datetime.strptime(header['date'], ss.ISO8601)
return md
def _handle_execute_reply(self, msg):
@@ -393,7 +425,10 @@ def _handle_execute_reply(self, msg):
parent = msg['parent_header']
msg_id = parent['msg_id']
if msg_id not in self.outstanding:
- print("got unknown result: %s"%msg_id)
+ if msg_id in self.history:
+ print ("got stale result: %s"%msg_id)
+ else:
+ print ("got unknown result: %s"%msg_id)
else:
self.outstanding.remove(msg_id)
self.results[msg_id] = ss.unwrap_exception(msg['content'])
@@ -403,7 +438,12 @@ def _handle_apply_reply(self, msg):
parent = msg['parent_header']
msg_id = parent['msg_id']
if msg_id not in self.outstanding:
- print ("got unknown result: %s"%msg_id)
+ if msg_id in self.history:
+ print ("got stale result: %s"%msg_id)
+ print self.results[msg_id]
+ print msg
+ else:
+ print ("got unknown result: %s"%msg_id)
else:
self.outstanding.remove(msg_id)
content = msg['content']
@@ -424,9 +464,10 @@ def _handle_apply_reply(self, msg):
pass
else:
e = ss.unwrap_exception(content)
- e_uuid = e.engine_info['engineid']
- eid = self._engines[e_uuid]
- e.engine_info['engineid'] = eid
+ if e.engine_info:
+ e_uuid = e.engine_info['engineid']
+ eid = self._engines[e_uuid]
+ e.engine_info['engineid'] = eid
self.results[msg_id] = e
def _flush_notifications(self):
@@ -811,6 +852,8 @@ def apply(self, f, args=None, kwargs=None, bound=True, block=None, targets=None,
elif after is None:
after = []
if isinstance(follow, Dependency):
+ # if len(follow) > 1 and follow.mode == 'all':
+ # warn("complex follow-dependencies are not rigorously tested for reachability", UserWarning)
follow = follow.as_dict()
elif isinstance(follow, AsyncResult):
follow=follow.msg_ids
@@ -827,7 +870,6 @@ def _apply_balanced(self, f, args, kwargs, bound=True, block=None,
after=None, follow=None):
"""The underlying method for applying functions in a load balanced
manner, via the task queue."""
-
subheader = dict(after=after, follow=follow)
bufs = ss.pack_apply_message(f,args,kwargs)
content = dict(bound=bound)
@@ -1,16 +1,15 @@
"""Dependency utilities"""
from IPython.external.decorator import decorator
+from error import UnmetDependency
+
# flags
ALL = 1 << 0
ANY = 1 << 1
HERE = 1 << 2
ANYWHERE = 1 << 3
-class UnmetDependency(Exception):
- pass
-
class depend(object):
"""Dependency decorator, for use with tasks."""
@@ -30,7 +29,7 @@ class dependent(object):
def __init__(self, f, df, *dargs, **dkwargs):
self.f = f
- self.func_name = self.f.func_name
+ self.func_name = getattr(f, '__name__', 'f')
self.df = df
self.dargs = dargs
self.dkwargs = dkwargs
@@ -39,6 +38,10 @@ def __call__(self, *args, **kwargs):
if self.df(*self.dargs, **self.dkwargs) is False:
raise UnmetDependency()
return self.f(*args, **kwargs)
+
+ @property
+ def __name__(self):
+ return self.func_name
def _require(*names):
for name in names:
@@ -57,18 +60,23 @@ class Dependency(set):
Subclassed from set()."""
mode='all'
+ success_only=True
- def __init__(self, dependencies=[], mode='all'):
+ def __init__(self, dependencies=[], mode='all', success_only=True):
if isinstance(dependencies, dict):
# load from dict
- dependencies = dependencies.get('dependencies', [])
mode = dependencies.get('mode', mode)
+ success_only = dependencies.get('success_only', success_only)
+ dependencies = dependencies.get('dependencies', [])
set.__init__(self, dependencies)
self.mode = mode.lower()
+ self.success_only=success_only
if self.mode not in ('any', 'all'):
raise NotImplementedError("Only any|all supported, not %r"%mode)
- def check(self, completed):
+ def check(self, completed, failed=None):
+ if failed is not None and not self.success_only:
+ completed = completed.union(failed)
if len(self) == 0:
return True
if self.mode == 'all':
@@ -78,13 +86,26 @@ def check(self, completed):
else:
raise NotImplementedError("Only any|all supported, not %r"%mode)
+ def unreachable(self, failed):
+ if len(self) == 0 or len(failed) == 0 or not self.success_only:
+ return False
+ print self, self.success_only, self.mode, failed
+ if self.mode == 'all':
+ return not self.isdisjoint(failed)
+ elif self.mode == 'any':
+ return self.issubset(failed)
+ else:
+ raise NotImplementedError("Only any|all supported, not %r"%mode)
+
+
def as_dict(self):
"""Represent this dependency as a dict. For json compatibility."""
return dict(
dependencies=list(self),
- mode=self.mode
+ mode=self.mode,
+ success_only=self.success_only,
)
-__all__ = ['UnmetDependency', 'depend', 'require', 'Dependency']
+__all__ = ['depend', 'require', 'Dependency']
@@ -148,6 +148,12 @@ class FileTimeoutError(KernelError):
class TimeoutError(KernelError):
pass
+class UnmetDependency(KernelError):
+ pass
+
+class ImpossibleDependency(UnmetDependency):
+ pass
+
class RemoteError(KernelError):
"""Error raised elsewhere"""
ename=None
Oops, something went wrong.

0 comments on commit d51586b

Please sign in to comment.