diff --git a/traitlets/tests/test_traitlets.py b/traitlets/tests/test_traitlets.py index cfe649db..1b8074d2 100644 --- a/traitlets/tests/test_traitlets.py +++ b/traitlets/tests/test_traitlets.py @@ -679,18 +679,18 @@ def test_observe_via_tags(self): class A(HasTraits): foo = Int() - bar = Int().tag(test=True) + bar = Int().tag(type=None) a = A() def _test_observer(change): a.foo = change['new'] - a.observe(_test_observer, tags={'test':True}) + a.observe(_test_observer, tags={'type': lambda v: v is None}) a.bar = 1 self.assertEqual(a.foo, a.bar) - a.add_traits(baz=Int().tag(test=True)) + a.add_traits(baz=Int().tag(type=None)) a.baz = 2 self.assertEqual(a.foo, a.baz) diff --git a/traitlets/traitlets.py b/traitlets/traitlets.py index a17b604b..e18e98b8 100644 --- a/traitlets/traitlets.py +++ b/traitlets/traitlets.py @@ -178,6 +178,23 @@ def parse_notifier_name(names): return names +class _TaggedNotifierContainer: + + def __init__(self, value): + self.notifiers = [] + if isinstance(value, types.FunctionType): + self.eval = True + else: + self.eval = False + self.value = value + + def __eq__(self, other): + if not self.eval or isinstance(other, types.FunctionType): + return self.value == other + else: + return self.value(other) + + class _SimpleTest: def __init__ ( self, value ): self.value = value def __call__ ( self, test ): @@ -1034,17 +1051,19 @@ def notify_change(self, change): trait = getattr(self.__class__, name) for k, v in trait.metadata.items(): try: - for n in self._trait_notifiers['tags'][k][v][type]: - if n not in callables: - callables.append(n) + if type in self._trait_notifiers['tags']: + d = self._trait_notifiers['tags'][type] + else: + d = self._trait_notifiers['tags'][All] + contianer = d[k][d[k].index(v)] except KeyError: pass - try: - for n in self._trait_notifiers['tags'][k][v][All]: - if n not in callables: - callables.append(n) - except KeyError: + except ValueError: pass + else: + for c in contianer.notifiers: + if c not in callables: + callables.append(c) # Now static ones magic_name = '_%s_changed' % name @@ -1086,22 +1105,25 @@ def _add_notifiers(self, handler, name, tags, type): if tags: tagged = self._trait_notifiers['tags'] + if type in tagged: + d = tagged[type] + else: + d = {} + tagged[type] = d for k, v in tags.items(): - if k in tagged and v in tagged[k]: - d = tagged[k][v] + if k in d: + l = d[k] else: - d = {} - if k in tagged: - tagged[k][v] = d - else: - tagged[k] = {v: d} - if type not in d: - nlist = [] - d[type] = nlist - else: - nlist = d[type] - if handler not in nlist: - nlist.append(handler) + l = [] + d[k] = l + if v in l: + i = l.index(v) + c = l[i] + else: + c = _TaggedNotifierContainer(v) + l.append(c) + if handler not in c.notifiers: + c.notifiers.append(handler) def _remove_notifiers(self, handler, name, tags, type): try: @@ -1109,24 +1131,28 @@ def _remove_notifiers(self, handler, name, tags, type): del self._trait_notifiers['names'][name][type] else: self._trait_notifiers['names'][name][type].remove(handler) - if name is not All: - trait = getattr(self.__class__, name, None) - if trait is not None: - for k, v in trait.metadata.items(): - if handler is None: - del self._trait_notifiers['tags'][k][v] - else: - self._trait_notifiers['tags'][k][v].remove(handler) - if tags: - for k, v in tags.items(): - if handler is None: - del self._trait_notifiers['tags'][k][v] - else: - self._trait_notifiers['tags'][k][v].remove(handler) except KeyError: pass - except AttributeError: - pass + + if name is not All: + trait = getattr(self.__class__, name, None) + if isinstance(trait, TraitType): + for k, v in trait.metadata.items(): + try: + if type in self._trait_notifiers['tags']: + d = self._trait_notifiers['tags'][type] + else: + d = self._trait_notifiers['tags'][All] + i = d[k].index(v) + except KeyError: + pass + except ValueError: + pass + else: + if hander is None: + d[k].remove(i) + else: + d[k][i].remove(handler) def on_trait_change(self, handler=None, name=None, remove=False): """DEPRECATED: Setup a handler to be called when a trait changes.