Skip to content

Commit

Permalink
WIP: wrap all notifiers in EventHandlers
Browse files Browse the repository at this point in the history
  • Loading branch information
rmorshea committed Oct 22, 2015
1 parent 1ae7d72 commit 48679e4
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 117 deletions.
8 changes: 4 additions & 4 deletions traitlets/tests/test_traitlets.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,10 +381,10 @@ class A(HasTraits):
a.on_trait_change(callback4, 'a')
a.a = 100000
self.assertEqual(self.cb,('a',10000,100000,a))
self.assertEqual(len(a._trait_notifiers['a']['change']), 1)
self.assertEqual(len(a._trait_notifiers), 1)
a.on_trait_change(callback4, 'a', remove=True)

self.assertEqual(len(a._trait_notifiers['a']['change']), 0)
self.assertEqual(len(a._trait_notifiers), 0)

def test_notify_only_once(self):

Expand Down Expand Up @@ -567,10 +567,10 @@ class A(HasTraits):
a.a = 100
change = change_dict('a', 10, 100, a, 'change')
self.assertEqual(self.cb, change)
self.assertEqual(len(a._trait_notifiers['a']['change']), 1)
self.assertEqual(len(a._trait_notifiers), 1)
a.unobserve(callback1, 'a')

self.assertEqual(len(a._trait_notifiers['a']['change']), 0)
self.assertEqual(len(a._trait_notifiers), 0)

def test_notify_only_once(self):

Expand Down
210 changes: 97 additions & 113 deletions traitlets/traitlets.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,14 +500,17 @@ def _validate(self, obj, value):
return value

def _cross_validate(self, obj, value):
if self.name in obj._trait_validators:
proposal = {'trait': self, 'value': value, 'owner': obj}
value = obj._trait_validators[self.name](obj, proposal)
elif hasattr(obj, '_%s_validate' % self.name):
for v in obj._trait_validators:
if self.name in v.names:
proposal = {'trait': self, 'value': value, 'owner': obj}
return v(obj, proposal)

if hasattr(obj, '_%s_validate' % self.name):
warn("_[traitname]_validate handlers are deprecated: use validate"
" decorator instead", DeprecationWarning, stacklevel=2)
cross_validate = getattr(obj, '_%s_validate' % self.name)
value = cross_validate(value, self)
return cross_validate(value, self)

return value

def __or__(self, other):
Expand Down Expand Up @@ -570,47 +573,6 @@ def default_value_repr(self):
# The HasTraits implementation
#-----------------------------------------------------------------------------

class _CallbackWrapper(object):
"""An object adapting a on_trait_change callback into an observe callback.
The comparison operator __eq__ is implemented to enable removal of wrapped
callbacks.
"""

def __init__(self, cb):
self.cb = cb
# Bound methods have an additional 'self' argument.
offset = -1 if isinstance(self.cb, types.MethodType) else 0
self.nargs = len(getargspec(cb)[0]) + offset
if (self.nargs > 4):
raise TraitError('a trait changed callback must have 0-4 arguments.')

def __eq__(self, other):
# The wrapper is equal to the wrapped element
if isinstance(other, _CallbackWrapper):
return self.cb == other.cb
else:
return self.cb == other

def __call__(self, change):
# The wrapper is callable
if self.nargs == 0:
self.cb()
elif self.nargs == 1:
self.cb(change['name'])
elif self.nargs == 2:
self.cb(change['name'], change['new'])
elif self.nargs == 3:
self.cb(change['name'], change['old'], change['new'])
elif self.nargs == 4:
self.cb(change['name'], change['old'], change['new'], change['owner'])

def _callback_wrapper(cb):
if isinstance(cb, _CallbackWrapper):
return cb
else:
return _CallbackWrapper(cb)


class MetaHasTraits(type):
"""Deprecated, use MetaHasDescriptors"""
Expand Down Expand Up @@ -737,9 +699,50 @@ def __init__(self, names, tags, type):
self.type = type

def instance_init(self, inst):
meth = types.MethodType(self.func, inst)
tagged = list(inst.trait_names(**self.tags))
inst.observe(self, self.names+tagged, type=self.type)
self.names = list(set(self.names + tagged))
inst._trait_notifiers.append(self)

def __eq__(self, other):
# The wrapper is equal to the wrapped element
if isinstance(other, self.__class__):
return self.func == other.func
else:
return self.func == other


class _CallbackWrapper(ObserveHandler):
"""An object adapting a on_trait_change callback into an observe callback.
The comparison operator __eq__ is implemented to enable removal of wrapped
callbacks.
"""

def _init_call(self, func):
self.func = func
offset = 1 if isinstance(func, types.MethodType) else 0
self.nargs = len(getargspec(func)[0]) - offset
if (self.nargs > 4):
raise TraitError('a trait changed callback must have 0-4 arguments.')
return self

def _func_call(self, change):
if self.nargs == 0:
self.func()
elif self.nargs == 1:
self.func(change['name'])
elif self.nargs == 2:
self.func(change['name'], change['new'])
elif self.nargs == 3:
self.func(change['name'], change['old'], change['new'])
elif self.nargs == 4:
self.func(change['name'], change['old'], change['new'], change['owner'])

def __call__(self, *args, **kwargs):
if hasattr(self, 'func'):
return self._func_call(*args, **kwargs)
else:
return self._init_call(*args, **kwargs)


class ValidateHandler(EventHandler):
Expand All @@ -750,6 +753,7 @@ def __init__(self, names):
def instance_init(self, inst):
meth = types.MethodType(self.func, inst)
inst._register_validator(self, self.names)



class HasDescriptors(py3compat.with_metaclass(MetaHasDescriptors, object)):
Expand Down Expand Up @@ -785,8 +789,8 @@ class HasTraits(HasDescriptors):

def install_descriptors(self, cls):
self._trait_values = {}
self._trait_notifiers = {}
self._trait_validators = {}
self._trait_notifiers = []
self._trait_validators = []
super(HasTraits, self).install_descriptors(cls)

def __init__(self, *args, **kw):
Expand All @@ -803,8 +807,8 @@ def __getstate__(self):
# event handlers stored on an instance are
# expected to be reinstantiated during a
# recall of instance_init during __setstate__
d['_trait_notifiers'] = {}
d['_trait_validators'] = {}
d['_trait_notifiers'] = []
d['_trait_validators'] = []
return d

def __setstate__(self, state):
Expand Down Expand Up @@ -905,11 +909,10 @@ def _notify_trait(self, name, old_value, new_value):
})

def _notify_change(self, name, type, change):
callables = []
callables.extend(self._trait_notifiers.get(name, {}).get(type, []))
callables.extend(self._trait_notifiers.get(name, {}).get(All, []))
callables.extend(self._trait_notifiers.get(All, {}).get(type, []))
callables.extend(self._trait_notifiers.get(All, {}).get(All, []))
callables = list()
for n in self._trait_notifiers:
if name in n.names and type==n.type:
callables.append(n)

# Now static ones
magic_name = '_%s_changed' % name
Expand All @@ -921,18 +924,20 @@ def _notify_change(self, name, type, change):
cb = getattr(self, '_%s_changed' % name)
# Only append the magic method if it was not manually registered
if cb not in callables:
callables.append(_callback_wrapper(cb))
wrap = _CallbackWrapper(name, tags={}, type='change')
callables.append(wrap(cb))

# Call them all now
# Traits catches and logs errors here. I allow them to raise
for c in callables:
# Bound methods have an additional 'self' argument.

if isinstance(c, _CallbackWrapper):
# _CallbackWrappers are not compatible with getargspec and have one argument
c = c.__call__
elif isinstance(c, EventHandler):
# Bound methods have an additional 'self' argument.
if c.name is not None:
c = getattr(self, c.name)
elif isinstance(c, _CallbackWrapper):
c = c._func_call
else:
c = c.func

offset = 1 if isinstance(c, types.MethodType) else 0
nargs = len(getargspec(c)[0]) - offset
Expand All @@ -945,28 +950,6 @@ def _notify_change(self, name, type, change):
raise TraitError('an observe change callback '
'must have 0-1 arguments.')

def _add_notifiers(self, handler, name, type):
if name not in self._trait_notifiers:
nlist = []
self._trait_notifiers[name] = {type: nlist}
else:
if type not in self._trait_notifiers[name]:
nlist = []
self._trait_notifiers[name][type] = nlist
else:
nlist = self._trait_notifiers[name][type]
if handler not in nlist:
nlist.append(handler)

def _remove_notifiers(self, handler, name, type):
try:
if handler is None:
del self._trait_notifiers[name][type]
else:
self._trait_notifiers[name][type].remove(handler)
except KeyError:
pass

def on_trait_change(self, handler=None, name=None, remove=False):
"""DEPRECATED: Setup a handler to be called when a trait changes.
Expand Down Expand Up @@ -1000,9 +983,11 @@ def on_trait_change(self, handler=None, name=None, remove=False):
if name is None:
name = All
if remove:
self.unobserve(_callback_wrapper(handler), names=name)
self.unobserve(handler, names=name)
else:
self.observe(_callback_wrapper(handler), names=name)
handler = _CallbackWrapper(name, tags={}, type='change')(handler)
handler.instance_init(self)
return handler

def observe(self, handler, names=All, tags=None, type='change'):
"""Setup a handler to be called when a trait changes.
Expand Down Expand Up @@ -1033,11 +1018,9 @@ def observe(self, handler, names=All, tags=None, type='change'):
The type of notification to filter by. If equal to All, then all
notifications are passed to the observe handler.
"""
names = parse_notifier_name(names)
if tags:
names.extend(self.trait_names(**tags))
for n in names:
self._add_notifiers(handler, n, type)
handler = ObserveHandler(names, tags=(tags or {}), type=type)(handler)
handler.instance_init(self)
return handler

def unobserve(self, handler, names=All, tags=None, type='change'):
"""Remove a trait change handler.
Expand All @@ -1059,11 +1042,21 @@ def unobserve(self, handler, names=All, tags=None, type='change'):
The type of notification to filter by. If All, the specified handler
is uninstalled from the list of notifiers corresponding to all types.
"""
if tags is not None:
names = list(names) + list(self.trait_names(**tags))
names = parse_notifier_name(names)
if tags:
names.extend(self.trait_names(**tags))
for n in names:
self._remove_notifiers(handler, n, type)

if handler is None:
for i in range(len(self._trait_notifiers)):
notifier = self._trait_notifiers[i]
if notifier.type == type:
for j in range(len(notifier.names)):
if notifier.names[j] in names:
notifier.names.pop(j)
else:
i = self._trait_notifiers.index(handler)
return self._trait_notifiers.pop(i)


def unobserve_all(self, name=All):
"""Remove trait change handlers of any type for the specified name.
Expand Down Expand Up @@ -1097,18 +1090,12 @@ def _register_validator(self, handler, names):
The names of the traits that should be cross-validated
"""
for name in names:
if name in self._trait_validators:
raise TraitError("A cross-validator for the trait"
" '%s' already exists" % name)

magic_name = '_%s_validate' % name
if hasattr(self, magic_name):
class_value = getattr(self.__class__, magic_name)
if not isinstance(class_value, ValidateHandler):
warn("_[traitname]_validate handlers are deprecated: use validate"
" decorator instead", DeprecationWarning, stacklevel=2)
for name in names:
self._trait_validators[name] = handler
for v in self._trait_validators:
if name in v.names:
raise TraitError("A cross-validator for the trait"
" '%s' already exists" % name)

self._trait_validators.append(handler)

@classmethod
def class_trait_names(cls, **metadata):
Expand Down Expand Up @@ -1222,15 +1209,12 @@ def add_traits(self, **traits):
"""Dynamically add trait attributes to the HasTraits instance."""
self.__class__ = type(self.__class__.__name__, (self.__class__,),
traits)

observers = [memb[1] for memb in getmembers(self.__class__)
if isinstance(memb[1], ObserveHandler)]

for name, trait in traits.items():
trait.instance_init(self)
for o in observers:
for o in self._trait_notifiers:
if trait.name in self.traits(**o.tags):
self.observe(o, trait.name, type=o.type)
o.names.append(trait.name)

def set_trait(self, name, value):
"""Forcibly sets trait attribute, including read-only attributes."""
Expand Down

0 comments on commit 48679e4

Please sign in to comment.