diff --git a/traitlets/tests/test_traitlets.py b/traitlets/tests/test_traitlets.py index 0e6a98d8..d81ada8b 100644 --- a/traitlets/tests/test_traitlets.py +++ b/traitlets/tests/test_traitlets.py @@ -381,10 +381,10 @@ class A(HasTraits): a.on_trait_change(callback4, 'a') a.a = 100000 self.assertEqual(self.cb,('a',10000,100000,a)) - self.assertEqual(len(a._trait_notifiers['a']['change']), 1) + self.assertEqual(len(a._trait_notifiers), 1) a.on_trait_change(callback4, 'a', remove=True) - self.assertEqual(len(a._trait_notifiers['a']['change']), 0) + self.assertEqual(len(a._trait_notifiers), 0) def test_notify_only_once(self): @@ -567,10 +567,10 @@ class A(HasTraits): a.a = 100 change = change_dict('a', 10, 100, a, 'change') self.assertEqual(self.cb, change) - self.assertEqual(len(a._trait_notifiers['a']['change']), 1) + self.assertEqual(len(a._trait_notifiers), 1) a.unobserve(callback1, 'a') - self.assertEqual(len(a._trait_notifiers['a']['change']), 0) + self.assertEqual(len(a._trait_notifiers), 0) def test_notify_only_once(self): diff --git a/traitlets/traitlets.py b/traitlets/traitlets.py index a1901876..f8d0a060 100644 --- a/traitlets/traitlets.py +++ b/traitlets/traitlets.py @@ -570,47 +570,6 @@ def default_value_repr(self): # The HasTraits implementation #----------------------------------------------------------------------------- -class _CallbackWrapper(object): - """An object adapting a on_trait_change callback into an observe callback. - - The comparison operator __eq__ is implemented to enable removal of wrapped - callbacks. - """ - - def __init__(self, cb): - self.cb = cb - # Bound methods have an additional 'self' argument. - offset = -1 if isinstance(self.cb, types.MethodType) else 0 - self.nargs = len(getargspec(cb)[0]) + offset - if (self.nargs > 4): - raise TraitError('a trait changed callback must have 0-4 arguments.') - - def __eq__(self, other): - # The wrapper is equal to the wrapped element - if isinstance(other, _CallbackWrapper): - return self.cb == other.cb - else: - return self.cb == other - - def __call__(self, change): - # The wrapper is callable - if self.nargs == 0: - self.cb() - elif self.nargs == 1: - self.cb(change['name']) - elif self.nargs == 2: - self.cb(change['name'], change['new']) - elif self.nargs == 3: - self.cb(change['name'], change['old'], change['new']) - elif self.nargs == 4: - self.cb(change['name'], change['old'], change['new'], change['owner']) - -def _callback_wrapper(cb): - if isinstance(cb, _CallbackWrapper): - return cb - else: - return _CallbackWrapper(cb) - class MetaHasTraits(type): """Deprecated, use MetaHasDescriptors""" @@ -737,9 +696,15 @@ def __init__(self, names, tags, type): self.type = type def instance_init(self, inst): - meth = types.MethodType(self.func, inst) tagged = list(inst.trait_names(**self.tags)) - inst.observe(self, self.names+tagged, type=self.type) + self.names = list(set(self.names + tagged)) + inst._trait_notifiers.append(self) + + def __eq__(self, other): + if self.name is None: + return self.func == other + else: + return self == other class ValidateHandler(EventHandler): @@ -752,6 +717,48 @@ def instance_init(self, inst): inst._register_validator(self, self.names) +class _CallbackWrapper(ObserveHandler): + """An object adapting a on_trait_change callback into an observe callback. + + The comparison operator __eq__ is implemented to enable removal of wrapped + callbacks. + """ + + def _init_call(self, func): + self.func = func + offset = 1 if isinstance(func, types.MethodType) else 0 + self.nargs = len(getargspec(func)[0]) - offset + if (self.nargs > 4): + raise TraitError('a trait changed callback must have 0-4 arguments.') + return self + + def _func_call(self, change): + if self.nargs == 0: + self.func() + elif self.nargs == 1: + self.func(change['name']) + elif self.nargs == 2: + self.func(change['name'], change['new']) + elif self.nargs == 3: + self.func(change['name'], change['old'], change['new']) + elif self.nargs == 4: + self.func(change['name'], change['old'], change['new'], change['owner']) + + def __call__(self, *args, **kwargs): + if hasattr(self, 'func'): + return self._func_call(*args, **kwargs) + else: + return self._init_call(*args, **kwargs) + + def __eq__(self, other): + # The wrapper is equal to the wrapped element + if isinstance(other, _CallbackWrapper): + return self.func == other.func + else: + return self.func == other + + + class HasDescriptors(py3compat.with_metaclass(MetaHasDescriptors, object)): """The base class for all classes that have descriptors. """ @@ -785,7 +792,7 @@ class HasTraits(HasDescriptors): def install_descriptors(self, cls): self._trait_values = {} - self._trait_notifiers = {} + self._trait_notifiers = [] self._trait_validators = {} super(HasTraits, self).install_descriptors(cls) @@ -803,7 +810,7 @@ def __getstate__(self): # event handlers stored on an instance are # expected to be reinstantiated during a # recall of instance_init during __setstate__ - d['_trait_notifiers'] = {} + d['_trait_notifiers'] = [] d['_trait_validators'] = {} return d @@ -905,11 +912,10 @@ def _notify_trait(self, name, old_value, new_value): }) def _notify_change(self, name, type, change): - callables = [] - callables.extend(self._trait_notifiers.get(name, {}).get(type, [])) - callables.extend(self._trait_notifiers.get(name, {}).get(All, [])) - callables.extend(self._trait_notifiers.get(All, {}).get(type, [])) - callables.extend(self._trait_notifiers.get(All, {}).get(All, [])) + callables = list() + for n in self._trait_notifiers: + if name in n.names and type==n.type: + callables.append(n) # Now static ones magic_name = '_%s_changed' % name @@ -921,18 +927,22 @@ def _notify_change(self, name, type, change): cb = getattr(self, '_%s_changed' % name) # Only append the magic method if it was not manually registered if cb not in callables: - callables.append(_callback_wrapper(cb)) + wrap = _CallbackWrapper(name, tags={}, type='change') + callables.append(wrap(cb)) # Call them all now # Traits catches and logs errors here. I allow them to raise for c in callables: - # Bound methods have an additional 'self' argument. - if isinstance(c, _CallbackWrapper): - # _CallbackWrappers are not compatible with getargspec and have one argument - c = c.__call__ - elif isinstance(c, EventHandler): + # Bound methods have an additional 'self' argument. + print(c) + if c.name is not None: c = getattr(self, c.name) + elif isinstance(c, _CallbackWrapper): + c = c._func_call + else: + c = c.func + print(c) offset = 1 if isinstance(c, types.MethodType) else 0 nargs = len(getargspec(c)[0]) - offset @@ -945,28 +955,6 @@ def _notify_change(self, name, type, change): raise TraitError('an observe change callback ' 'must have 0-1 arguments.') - def _add_notifiers(self, handler, name, type): - if name not in self._trait_notifiers: - nlist = [] - self._trait_notifiers[name] = {type: nlist} - else: - if type not in self._trait_notifiers[name]: - nlist = [] - self._trait_notifiers[name][type] = nlist - else: - nlist = self._trait_notifiers[name][type] - if handler not in nlist: - nlist.append(handler) - - def _remove_notifiers(self, handler, name, type): - try: - if handler is None: - del self._trait_notifiers[name][type] - else: - self._trait_notifiers[name][type].remove(handler) - except KeyError: - pass - def on_trait_change(self, handler=None, name=None, remove=False): """DEPRECATED: Setup a handler to be called when a trait changes. @@ -1000,9 +988,11 @@ def on_trait_change(self, handler=None, name=None, remove=False): if name is None: name = All if remove: - self.unobserve(_callback_wrapper(handler), names=name) + self.unobserve(handler, names=name) else: - self.observe(_callback_wrapper(handler), names=name) + handler = _CallbackWrapper(name, tags={}, type='change')(handler) + handler.instance_init(self) + return handler def observe(self, handler, names=All, tags=None, type='change'): """Setup a handler to be called when a trait changes. @@ -1033,11 +1023,9 @@ def observe(self, handler, names=All, tags=None, type='change'): The type of notification to filter by. If equal to All, then all notifications are passed to the observe handler. """ - names = parse_notifier_name(names) - if tags: - names.extend(self.trait_names(**tags)) - for n in names: - self._add_notifiers(handler, n, type) + handler = ObserveHandler(names, tags=(tags or {}), type=type)(handler) + handler.instance_init(self) + return handler def unobserve(self, handler, names=All, tags=None, type='change'): """Remove a trait change handler. @@ -1059,11 +1047,21 @@ def unobserve(self, handler, names=All, tags=None, type='change'): The type of notification to filter by. If All, the specified handler is uninstalled from the list of notifiers corresponding to all types. """ + if tags is not None: + names = list(names) + list(self.trait_names(**tags)) names = parse_notifier_name(names) - if tags: - names.extend(self.trait_names(**tags)) - for n in names: - self._remove_notifiers(handler, n, type) + + if handler is None: + for i in range(len(self._trait_notifiers)): + notifier = self._trait_notifiers[i] + if notifier.type == type: + for j in range(len(notifier.names)): + if notifier.names[j] in names: + notifier.names.pop(j) + else: + i = self._trait_notifiers.index(handler) + return self._trait_notifiers.pop(i) + def unobserve_all(self, name=All): """Remove trait change handlers of any type for the specified name. @@ -1222,15 +1220,12 @@ def add_traits(self, **traits): """Dynamically add trait attributes to the HasTraits instance.""" self.__class__ = type(self.__class__.__name__, (self.__class__,), traits) - - observers = [memb[1] for memb in getmembers(self.__class__) - if isinstance(memb[1], ObserveHandler)] for name, trait in traits.items(): trait.instance_init(self) - for o in observers: + for o in self._trait_notifiers: if trait.name in self.traits(**o.tags): - self.observe(o, trait.name, type=o.type) + o.names.append(trait.name) def set_trait(self, name, value): """Forcibly sets trait attribute, including read-only attributes."""