From b4c913192842f5e6610f8177e88041edd1a93fd6 Mon Sep 17 00:00:00 2001 From: Ryan Morshead Date: Fri, 1 Jul 2016 09:37:45 -0700 Subject: [PATCH] misc and tests for validate and default via tags --- traitlets/tests/test_traitlets.py | 52 ++++++++++++++++++++++++++++- traitlets/traitlets.py | 54 ++++++++++++++++++------------- 2 files changed, 82 insertions(+), 24 deletions(-) diff --git a/traitlets/tests/test_traitlets.py b/traitlets/tests/test_traitlets.py index 87951711..e84c91fb 100644 --- a/traitlets/tests/test_traitlets.py +++ b/traitlets/tests/test_traitlets.py @@ -563,7 +563,6 @@ class A(HasTraits): bar = Int().tag(type='a', obj=u) baz = Int().tag(type='z') - a = A() def _test_observer1(change): @@ -626,6 +625,57 @@ def _test_observer(change): a.bar = 2 self.assertEqual(a.foo, 0) + def test_validate_via_tags(self): + + def _domain_like(x): + return isinstance(x, (tuple, list)) and len(x) == 2 + + class A(HasTraits): + foo = Int().tag(domain=(1, 11)) + bar = Int().tag(domain=(4, 8)) + baz = Int().tag(domain=None) + + @validate(tags={'domain': lambda x: _domain_like(x)}) + def _domain_type_coecer(self, prop): + d = prop.trait.metadata['domain'] + if prop.value <= d[0]: + return d[0] + elif prop.value >= d[1]: + return d[1]-1 + else: + return prop.value + + a = A() + + a.foo = 0 + self.assertEqual(a.foo, 1) + a.foo = 11 + self.assertEqual(a.foo, 10) + a.foo = 6 + self.assertEqual(a.foo, 6) + + a.bar = 0 + self.assertEqual(a.bar, 4) + a.bar = 11 + self.assertEqual(a.bar, 7) + a.bar = 6 + self.assertEqual(a.bar, 6) + + def test_default_via_tags(self): + + class A(HasTraits): + foo = Int().tag(type='a') + bar = Int().tag(type='a') + baz = Int().tag(type='b') + + @default(tags={'type': 'a'}) + def _a_type_default(self): + return 1 + + a = A() + self.assertEqual(a.foo, 1) + self.assertEqual(a.bar, 1) + self.assertEqual(a.baz, 0) def test_subclass(self): diff --git a/traitlets/traitlets.py b/traitlets/traitlets.py index 2586e699..5d7463f4 100644 --- a/traitlets/traitlets.py +++ b/traitlets/traitlets.py @@ -66,7 +66,6 @@ # Basic classes #----------------------------------------------------------------------------- - Undefined = Sentinel('Undefined', 'traitlets', ''' Used in Traitlets to specify that no defaults are set in kwargs @@ -182,15 +181,13 @@ def parse_notifier_name(names): return [] elif names is All or isinstance(names, six.string_types): return [names] + elif not names or All in names: + return [All] else: - try: - names = list(names) - except: - raise TypeError("could not coerce to 'list'") for n in names: - if n is not All and not isinstance(n, six.string_types): - raise ValueError("names must be strings or %r" % All) - return names + if not isinstance(n, six.string_types): + raise TypeError("names must be strings, not %s" % n) + return list(names) def parse_notifier_tags(obj, tags): @@ -203,10 +200,17 @@ def parse_notifier_tags(obj, tags): tags: dict The tags being converted to trait names """ + if isinstance(obj, HasTraits): + method = obj.trait_names + elif issubclass(obj, HasTraits): + method = obj.class_trait_names + else: + raise TypeError("Expected an instance or class from a HasTraits subclass") + if tags is None or not len(tags): return [] else: - return list(obj.trait_names(**tags)) + return list(method(**tags)) class _SimpleTest: @@ -910,13 +914,13 @@ def __get__(self, inst, cls=None): class TraitEventHandler(EventHandler): metadata = {} - trait_names = () + trait_names = [] def __init__(self, names, tags): + if names: self.trait_names = parse_notifier_name(names) if tags: self.metadata = tags - if names: self.trait_names = names - def register(self, inst, names=None): + def register(self, inst): """Associate this event with traits on an instance""" pass @@ -937,12 +941,11 @@ def event_of(self, obj, names=None): ------- A ``set`` of trait names """ - named = parse_notifier_name(self.trait_names) tagged = parse_notifier_tags(obj, self.metadata) if names is not None: - return set(names).intersection(named + tagged) + return set(names).intersection(self.trait_names + tagged) else: - return set(named + tagged) + return set(self.trait_names + tagged) def __eq__(self, other): if isinstance(other, TraitEventHandler): @@ -1416,19 +1419,24 @@ def unobserve_all(self, name=All, type=All): self._trait_notifiers = {} self._static_trait_notifiers = [] else: - names = self.trait_names() if name is All else [name] - try: - for n in self.trait_names(): - types = self._trait_notifiers[n] if type is All else [type] - for t in types: + for n in self._trait_notifiers if name is All else [name]: + if type is All: + try: + tnames = self._trait_notifiers[n] + except: + continue + else: + tnames = [type] + for t in tnames: + try: handlers = self._trait_notifiers[n][t] + except KeyError: + pass + else: del self._trait_notifiers[n][t] for h in handlers: if h.name is None: self._static_trait_notifiers.remove(h) - except KeyError: - pass - def _register_validator(self, handler, names): """Setup a handler to be called when a trait should be cross validated.