Skip to content

Commit

Permalink
misc and tests for validate and default via tags
Browse files Browse the repository at this point in the history
  • Loading branch information
rmorshea committed Jul 1, 2016
1 parent 3b96139 commit b4c9131
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 24 deletions.
52 changes: 51 additions & 1 deletion traitlets/tests/test_traitlets.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,6 @@ class A(HasTraits):
bar = Int().tag(type='a', obj=u)
baz = Int().tag(type='z')


a = A()

def _test_observer1(change):
Expand Down Expand Up @@ -626,6 +625,57 @@ def _test_observer(change):
a.bar = 2
self.assertEqual(a.foo, 0)

def test_validate_via_tags(self):

def _domain_like(x):
return isinstance(x, (tuple, list)) and len(x) == 2

class A(HasTraits):
foo = Int().tag(domain=(1, 11))
bar = Int().tag(domain=(4, 8))
baz = Int().tag(domain=None)

@validate(tags={'domain': lambda x: _domain_like(x)})
def _domain_type_coecer(self, prop):
d = prop.trait.metadata['domain']
if prop.value <= d[0]:
return d[0]
elif prop.value >= d[1]:
return d[1]-1
else:
return prop.value

a = A()

a.foo = 0
self.assertEqual(a.foo, 1)
a.foo = 11
self.assertEqual(a.foo, 10)
a.foo = 6
self.assertEqual(a.foo, 6)

a.bar = 0
self.assertEqual(a.bar, 4)
a.bar = 11
self.assertEqual(a.bar, 7)
a.bar = 6
self.assertEqual(a.bar, 6)

def test_default_via_tags(self):

class A(HasTraits):
foo = Int().tag(type='a')
bar = Int().tag(type='a')
baz = Int().tag(type='b')

@default(tags={'type': 'a'})
def _a_type_default(self):
return 1

a = A()
self.assertEqual(a.foo, 1)
self.assertEqual(a.bar, 1)
self.assertEqual(a.baz, 0)

def test_subclass(self):

Expand Down
54 changes: 31 additions & 23 deletions traitlets/traitlets.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@
# Basic classes
#-----------------------------------------------------------------------------


Undefined = Sentinel('Undefined', 'traitlets',
'''
Used in Traitlets to specify that no defaults are set in kwargs
Expand Down Expand Up @@ -182,15 +181,13 @@ def parse_notifier_name(names):
return []
elif names is All or isinstance(names, six.string_types):
return [names]
elif not names or All in names:
return [All]
else:
try:
names = list(names)
except:
raise TypeError("could not coerce to 'list'")
for n in names:
if n is not All and not isinstance(n, six.string_types):
raise ValueError("names must be strings or %r" % All)
return names
if not isinstance(n, six.string_types):
raise TypeError("names must be strings, not %s" % n)
return list(names)


def parse_notifier_tags(obj, tags):
Expand All @@ -203,10 +200,17 @@ def parse_notifier_tags(obj, tags):
tags: dict
The tags being converted to trait names
"""
if isinstance(obj, HasTraits):
method = obj.trait_names
elif issubclass(obj, HasTraits):
method = obj.class_trait_names
else:
raise TypeError("Expected an instance or class from a HasTraits subclass")

if tags is None or not len(tags):
return []
else:
return list(obj.trait_names(**tags))
return list(method(**tags))


class _SimpleTest:
Expand Down Expand Up @@ -910,13 +914,13 @@ def __get__(self, inst, cls=None):
class TraitEventHandler(EventHandler):

metadata = {}
trait_names = ()
trait_names = []

def __init__(self, names, tags):
if names: self.trait_names = parse_notifier_name(names)
if tags: self.metadata = tags
if names: self.trait_names = names

def register(self, inst, names=None):
def register(self, inst):
"""Associate this event with traits on an instance"""
pass

Expand All @@ -937,12 +941,11 @@ def event_of(self, obj, names=None):
-------
A ``set`` of trait names
"""
named = parse_notifier_name(self.trait_names)
tagged = parse_notifier_tags(obj, self.metadata)
if names is not None:
return set(names).intersection(named + tagged)
return set(names).intersection(self.trait_names + tagged)
else:
return set(named + tagged)
return set(self.trait_names + tagged)

def __eq__(self, other):
if isinstance(other, TraitEventHandler):
Expand Down Expand Up @@ -1416,19 +1419,24 @@ def unobserve_all(self, name=All, type=All):
self._trait_notifiers = {}
self._static_trait_notifiers = []
else:
names = self.trait_names() if name is All else [name]
try:
for n in self.trait_names():
types = self._trait_notifiers[n] if type is All else [type]
for t in types:
for n in self._trait_notifiers if name is All else [name]:
if type is All:
try:
tnames = self._trait_notifiers[n]
except:
continue
else:
tnames = [type]
for t in tnames:
try:
handlers = self._trait_notifiers[n][t]
except KeyError:
pass
else:
del self._trait_notifiers[n][t]
for h in handlers:
if h.name is None:
self._static_trait_notifiers.remove(h)
except KeyError:
pass


def _register_validator(self, handler, names):
"""Setup a handler to be called when a trait should be cross validated.
Expand Down

0 comments on commit b4c9131

Please sign in to comment.