diff --git a/traitlets/tests/test_traitlets.py b/traitlets/tests/test_traitlets.py index f44246e29..879517111 100644 --- a/traitlets/tests/test_traitlets.py +++ b/traitlets/tests/test_traitlets.py @@ -534,6 +534,99 @@ class A(HasTraits): self.assertTrue(change in self._notify1) self.assertRaises(TraitError,setattr,a,'a','bad string') + def test_observe_decorator_via_tags(self): + + class A(HasTraits): + foo = Int() + bar = Int().tag(test=True) + + @observe(tags={'test':True}) + def _test_observer(self, change): + self.foo = change['new'] + + a = A() + a.bar = 1 + self.assertEqual(a.foo, a.bar) + + a.add_traits(baz=Int().tag(test=True)) + a.baz = 2 + self.assertEqual(a.foo, a.baz) + + def test_observe_via_tags(self): + + class unhashable(object): + __hash__ = None + u = unhashable() + + class A(HasTraits): + foo = Int() + bar = Int().tag(type='a', obj=u) + baz = Int().tag(type='z') + + + a = A() + + def _test_observer1(change): + a.foo += 1 + def _test_observer2(change): + a.foo += 1 + + # test that multiple evals will register together + + a.observe(_test_observer1, tags={'type': lambda v: v in ('a','c')}) + a.observe(_test_observer2, tags={'type': lambda v: v in ('a','b')}) + + a.bar = 1 + self.assertEqual(a.foo, 2) + a.foo = 0 + + a.unobserve_all() + + # test that hashable and unhashable tags register + a.observe(_test_observer1, tags={'type': 'a'}) + a.observe(_test_observer2, tags={'obj': u}) + + a.bar = 2 + self.assertEqual(a.foo, 2) + a.foo = 0 + + a.unobserve_all() + + # test that tagged notifiers know + # about dynamically added traits + a.observe(_test_observer1, tags={'type': 'b'}) + a.add_traits(baz=Int().tag(type='b')) + + a.baz = 1 + self.assertEqual(a.foo, 1) + a.foo = 0 + + a.unobserve_all() + + def test_unobserve_via_tags(self): + + class A(HasTraits): + foo = Int() + bar = Int().tag(type='a') + + a = A() + + def _test_observer(change): + a.foo += 1 + + a.observe(_test_observer, tags={'type': 'a'}) + a.unobserve(_test_observer, names='bar') + + a.bar = 1 + self.assertEqual(a.foo, 0) + + a.observe(_test_observer, tags={'type': 'a'}) + a.unobserve(_test_observer, tags={'type': 'a'}) + + a.bar = 2 + self.assertEqual(a.foo, 0) + + def test_subclass(self): class A(HasTraits): diff --git a/traitlets/traitlets.py b/traitlets/traitlets.py index 7815199dd..132a4c2e5 100644 --- a/traitlets/traitlets.py +++ b/traitlets/traitlets.py @@ -164,10 +164,10 @@ def is_trait(t): def parse_notifier_name(names): """Convert the name argument to a list of names. - Examples -------- - + >>> parse_notifier_name(None) + [] >>> parse_notifier_name([]) [All] >>> parse_notifier_name('a') @@ -177,16 +177,37 @@ def parse_notifier_name(names): >>> parse_notifier_name(All) [All] """ - if names is All or isinstance(names, six.string_types): + if names is None: + return [] + elif names is All or isinstance(names, six.string_types): return [names] - elif isinstance(names, (list, tuple)): - if not names or All in names: - return [All] + else: + try: + names = list(names) + except: + raise TypeError("Expected coercable to 'list', not %r" % names) for n in names: - assert isinstance(n, six.string_types), "names must be strings" + if n is not All and not isinstance(n, six.string_types): + raise ValueError("names must be strings or %r" % All) return names +def parse_notifier_tags(obj, tags): + """Convert the tags argument to a list of names. + + Parameters + ---------- + obj: HasTraits instance or class + The object to which the tags apply + tags: dict + The tags being converted to trait names + """ + if tags is None or not len(tags): + return [] + else: + return obj.trait_names(**tags) + + class _SimpleTest: def __init__ ( self, value ): self.value = value def __call__ ( self, test ): @@ -759,7 +780,8 @@ def observe(*names, **kwargs): *names The str names of the Traits to observe on the object. """ - return ObserveHandler(names, type=kwargs.get('type', 'change')) + return ObserveHandler(names, tags=kwargs.get('tags', {}), + type=kwargs.get('type', 'change')) def observe_compat(func): @@ -793,7 +815,7 @@ def compatible_observer(self, change_or_name, old=Undefined, new=Undefined): return compatible_observer -def validate(*names): +def validate(*names, **kwargs): """A decorator to register cross validator of HasTraits object's state when a Trait is set. @@ -817,10 +839,10 @@ def validate(*names): exiting the ``hold_trait_notifications` context, and such changes may not commute. """ - return ValidateHandler(names) + return ValidateHandler(names, tags=kwargs.get('tags', {})) -def default(name): +def default(*names, **kwargs): """ A decorator which assigns a dynamic default for a Trait on a HasTraits object. Parameters @@ -858,18 +880,20 @@ def some_other_default(self): # This default generator should not be return 3.0 # ignored since it is defined in a # class derived from B.a.this_class. """ - return DefaultHandler(name) + return DefaultHandler(names, tags=kwargs.get('tags', {})) class EventHandler(BaseDescriptor): + func = None + def _init_call(self, func): self.func = func return self def __call__(self, *args, **kwargs): """Pass `*args` and `**kwargs` to the handler's funciton if it exists.""" - if hasattr(self, 'func'): + if self.func is not None: return self.func(*args, **kwargs) else: return self._init_call(*args, **kwargs) @@ -880,33 +904,131 @@ def __get__(self, inst, cls=None): return types.MethodType(self.func, inst) -class ObserveHandler(EventHandler): +class TraitEventHandler(EventHandler): + + metadata = {} + trait_names = () + + def __init__(self, names, tags): + if tags: self.metadata = tags + if names: self.trait_names = names + + def register(self, inst, names=None): + """Associate this event with traits on an instance""" + pass + + def event_of(self, obj, names=None): + """Get the object's trait names this instance is an event of + + Parameters + ---------- + obj: HasTraits instance or class + The instance or class being examined + name: list, tuple + A list of those trait's names. Only those names which + this instance is an event of will be returned. If no + names are given, then all the trait names of this event + are returned instead. + + Returns + ------- + A ``set`` of trait names + """ + named = parse_notifier_name(self.trait_names) + tagged = parse_notifier_tags(obj, self.metadata) + if names is not None: + return set(names).intersection(named + tagged) + else: + return set(named + tagged) + + def __eq__(self, other): + if isinstance(other, TraitEventHandler): + return self is other + else: + return other == self.func + + +class ObserveHandler(TraitEventHandler): - def __init__(self, names, type): - self.trait_names = names + caches_instance_resources = True + + def __init__(self, names, tags, type): + super(ObserveHandler, self).__init__(names, tags) self.type = type def instance_init(self, inst): - inst.observe(self, self.trait_names, type=self.type) + self.register(inst) + + def register(self, inst, names=None): + if names is not None: + matched = self.event_of(inst, names) + if len(matched)!=len(names): + diff = set(names).difference(matched) + m = "Not handling events for %r" + raise TraitError(m % list(diff)) + else: + names = list(matched) + else: + names = self.event_of(inst) + inst.observe(self, names, type=self.type) -class ValidateHandler(EventHandler): +class ValidateHandler(TraitEventHandler): - def __init__(self, names): - self.trait_names = names + caches_instance_resources = True def instance_init(self, inst): - inst._register_validator(self, self.trait_names) + self.register(inst) + + def register(self, inst, names=None): + if names is not None: + matched = self.event_of(inst, names) + if len(matched)!=names: + diff = set(names).difference(matched) + m = "Not handling events for %r" + raise TraitError(m % diff) + else: + names = list(matched) + else: + names = self.event_of(inst) + inst._register_validator(self, names) -class DefaultHandler(EventHandler): +class DefaultHandler(TraitEventHandler): - def __init__(self, name): - self.trait_name = name + caches_class_resources = True def class_init(self, cls, name): super(DefaultHandler, self).class_init(cls, name) - cls._trait_default_generators[self.trait_name] = self + self.register(cls) + + def register(self, cls, names=None): + """Associate this event with traits on a class""" + if not issubclass(cls, HasTraits): + raise TypeError("Expected a HasTraits subclass") + elif names is not None: + matched = self.event_of(cls, names) + if len(matched)!=names: + diff = set(names).difference(matched) + m = "Not handling events for %r" + raise TraitError(m % diff) + else: + names = list(matched) + else: + names = self.event_of(cls) + self._register(cls, names) + + def _register(self, cls, names): + # Class creation prevents registration logic from being + # properly overridden if done on the class, because super + # fails. Thus the event is registered here, through the + # descriptor, so super can be used when overriding it. + for n in names: + if n in cls._trait_default_generators: + raise TraitError("%s has a conflicting default" + " generator for '%s'" % (cls.__name__, n)) + else: + cls._trait_default_generators[n] = self class HasDescriptors(six.with_metaclass(MetaHasDescriptors, object)): @@ -949,6 +1071,7 @@ def setup_instance(self, *args, **kwargs): self._trait_values = {} self._trait_notifiers = {} self._trait_validators = {} + self._static_trait_notifiers = [] super(HasTraits, self).setup_instance(*args, **kwargs) def __init__(self, *args, **kwargs): @@ -990,6 +1113,7 @@ def __getstate__(self): # recall of instance_init during __setstate__ d['_trait_notifiers'] = {} d['_trait_validators'] = {} + d['_static_trait_notifiers'] = [] return d def __setstate__(self, state): @@ -1135,32 +1259,36 @@ def notify_change(self, change): if isinstance(c, _CallbackWrapper): c = c.__call__ - elif isinstance(c, EventHandler) and c.name is not None: + elif isinstance(c, TraitEventHandler) and c.name is not None: c = getattr(self, c.name) c(change) - def _add_notifiers(self, handler, name, type): - if name not in self._trait_notifiers: - nlist = [] - self._trait_notifiers[name] = {type: nlist} - else: - if type not in self._trait_notifiers[name]: + def _add_notifiers(self, handler, names, type): + for name in names: + if name not in self._trait_notifiers: nlist = [] - self._trait_notifiers[name][type] = nlist + self._trait_notifiers[name] = {type: nlist} else: - nlist = self._trait_notifiers[name][type] - if handler not in nlist: - nlist.append(handler) + if type not in self._trait_notifiers[name]: + nlist = [] + self._trait_notifiers[name][type] = nlist + else: + nlist = self._trait_notifiers[name][type] + if handler not in nlist: + nlist.append(handler) + if handler.name is None and handler not in self._static_trait_notifiers: + self._static_trait_notifiers.append(handler) - def _remove_notifiers(self, handler, name, type): - try: - if handler is None: - del self._trait_notifiers[name][type] - else: - self._trait_notifiers[name][type].remove(handler) - except KeyError: - pass + def _remove_notifiers(self, handler, names, type): + for name in names: + try: + i = self._trait_notifiers[name][type].index(handler) + event = self._trait_notifiers[name][type].pop(i) + except KeyError: + pass + if event.name is None: + self._static_trait_notifiers.remove(handler) def on_trait_change(self, handler=None, name=None, remove=False): """DEPRECATED: Setup a handler to be called when a trait changes. @@ -1199,7 +1327,7 @@ def on_trait_change(self, handler=None, name=None, remove=False): else: self.observe(_callback_wrapper(handler), names=name) - def observe(self, handler, names=All, type='change'): + def observe(self, handler, names=None, tags=None, type='change'): """Setup a handler to be called when a trait changes. This is used to setup dynamic notifications of trait changes. @@ -1225,11 +1353,19 @@ def observe(self, handler, names=All, type='change'): The type of notification to filter by. If equal to All, then all notifications are passed to the observe handler. """ - names = parse_notifier_name(names) - for n in names: - self._add_notifiers(handler, n, type) + if isinstance(tags, six.string_types): + warn("new argument 'tags' introduced: use 'type' as keyword", + DeprecationWarning, stacklevel=2) + type, tags = tags, None + if names is None and tags is None: + names = All + if isinstance(handler, ObserveHandler): + self._add_notifiers(handler, names, type) + else: + event = ObserveHandler(names, tags, type) + event(handler).register(self) - def unobserve(self, handler, names=All, type='change'): + def unobserve(self, handler, names=None, tags=None, type='change'): """Remove a trait change handler. This is used to unregister handlers to trait change notificiations. @@ -1246,21 +1382,48 @@ def unobserve(self, handler, names=All, type='change'): The type of notification to filter by. If All, the specified handler is uninstalled from the list of notifiers corresponding to all types. """ - names = parse_notifier_name(names) - for n in names: - self._remove_notifiers(handler, n, type) + if isinstance(tags, six.string_types): + warn("new argument 'tags' introduced: use 'type' as keyword", + DeprecationWarning, stacklevel=2) + type, tags = tags, None + if names is None and tags is None: + names = All + named = parse_notifier_name(names) + tagged = parse_notifier_tags(self, tags) + self._remove_notifiers(handler, named + tagged, type) + + def unobserve_all(self, name=All, type=All): + """Remove trait change handlers of for the given name and type. - def unobserve_all(self, name=All): - """Remove trait change handlers of any type for the specified name. - If name is not specified, removes all trait notifiers.""" - if name is All: + Parameters + ---------- + name: str, All + If name is All, then all handlers of the given type are removed. + If a name is specified, then all handlers of that name and type + are removed instead. + type: str, All + If type is All, then all handlers of the given name are removed. + If a type is specified, then all handlers of that type and name + are removed instead. + """ + if name is All and type is All: self._trait_notifiers = {} + self._static_trait_notifiers = [] else: + names = self.trait_names() if name is All else [name] try: - del self._trait_notifiers[name] + for n in self.trait_names(): + types = self._trait_notifiers[n] if type is All else [type] + for t in types: + handlers = self._trait_notifiers[n][t] + del self._trait_notifiers[n][t] + for h in handlers: + if h.name is None: + self._static_trait_notifiers.remove(h) except KeyError: pass + def _register_validator(self, handler, names): """Setup a handler to be called when a trait should be cross valdiated. @@ -1295,6 +1458,15 @@ def add_traits(self, **traits): """Dynamically add trait attributes to the HasTraits instance.""" self.__class__ = type(self.__class__.__name__, (self.__class__,), traits) + + events = self._static_trait_notifiers + events.extend(self.trait_events().values()) + + for e in events: + if e.caches_instance_resources: + for n in e.event_of(self, traits.keys()): + e.register(self, (n,)) + for trait in traits.values(): trait.instance_init(self) @@ -1325,7 +1497,7 @@ def has_trait(cls, name): @classmethod def trait_names(cls, **metadata): """Get a list of all the names of this class' traits.""" - return cls.traits(**metadata).keys() + return list(cls.traits(**metadata).keys()) @classmethod def class_trait_names(cls, **metadata): @@ -1336,7 +1508,7 @@ def class_trait_names(cls, **metadata): """ warn("``HasTraits.class_trait_names`` is deprecated in favor of ``HasTraits.trait_names``" " as a classmethod", DeprecationWarning, stacklevel=2) - return cls.traits(**metadata).keys() + return list(cls.traits(**metadata).keys()) @classmethod def class_traits(cls, **metadata): @@ -2135,7 +2307,7 @@ def __init__(self, trait=None, default_value=None, **metadata): if is_trait(trait): if isinstance(trait, type): warn("Traits should be given as instances, not types (for example, `Int()`, not `Int`)", - DeprecationWarning, stacklevel=3) + DeprecationWarning, stacklevel=2) self._trait = trait() if isinstance(trait, type) else trait elif trait is not None: raise TypeError("`trait` must be a Trait or None, got %s" % repr_type(trait))