Skip to content
Browse files

Allow type checking on elements of List,Tuple,Set

  • Loading branch information...
1 parent 3273c8c commit 4c91ed72e1e6028c95fd24a806c868b0b9ff01cf @minrk committed
Showing with 184 additions and 33 deletions.
  1. +39 −1 IPython/utils/tests/test_traitlets.py
  2. +145 −32 IPython/utils/traitlets.py
View
40 IPython/utils/tests/test_traitlets.py
@@ -27,7 +27,7 @@
from IPython.utils.traitlets import (
HasTraits, MetaHasTraits, TraitType, Any,
Int, Long, Float, Complex, Str, Unicode, TraitError,
- Undefined, Type, This, Instance, TCPAddress
+ Undefined, Type, This, Instance, TCPAddress, List
)
@@ -739,3 +739,41 @@ class TestTCPAddress(TraitTestBase):
_default_value = ('127.0.0.1',0)
_good_values = [('localhost',0),('192.168.0.1',1000),('www.google.com',80)]
_bad_values = [(0,0),('localhost',10.0),('localhost',-1)]
+
+class TypedListTrait(HasTraits):
+
+ value = List(Int)
+
+class TestTypedList(TraitTestBase):
+
+ obj = TypedListTrait()
+
+ _default_value = []
+ _good_values = [[], [1], range(10)]
+ _bad_values = [10, [1,'a'], 'a', (1,2)]
+
+class LengthTypedListTrait(HasTraits):
+
+ value = List([Int])
+
+class TestLengthTypedList(TraitTestBase):
+
+ obj = LengthTypedListTrait()
+
+ _default_value = None
+ _good_values = [[1], None,[0]]
+ _bad_values = [10, [1,2], (1,),['a'], []]
+
+
+class MultiTypedListTrait(HasTraits):
+
+ value = List((Int, Str))
+
+class TestMultiTypedList(TraitTestBase):
+
+ obj = MultiTypedListTrait()
+
+ _default_value = None
+ _good_values = [[1,'a'], [2,'b']]
+ _bad_values = [[],10, 'a', [1,'a',3], ['a',1]]
+
View
177 IPython/utils/traitlets.py
@@ -328,7 +328,7 @@ def error(self, obj, value):
self.info(), repr_type(value))
else:
e = "The '%s' trait must be %s, but a value of %r was specified." \
- % (self.name, self.info(), repr_type(value))
+ % (self.name, self.info(), repr_type(value))
raise TraitError(e)
def get_metadata(self, key):
@@ -579,7 +579,15 @@ def error(self, obj, value):
else:
msg = '%s (i.e. %s)' % ( str( kind )[1:-1], repr( value ) )
- super(ClassBasedTraitType, self).error(obj, msg)
+ if obj is not None:
+ e = "The '%s' trait of %s instance must be %s, but a value of %s was specified." \
+ % (self.name, class_of(obj),
+ self.info(), msg)
+ else:
+ e = "The '%s' trait must be %s, but a value of %r was specified." \
+ % (self.name, self.info(), msg)
+
+ raise TraitError(e)
class Type(ClassBasedTraitType):
@@ -1013,46 +1021,152 @@ def validate(self, obj, value):
return v
self.error(obj, value)
+class Container(Instance):
+ """An instance of a container (list, set, etc.)
+
+ To be subclassed by overriding klass.
+ """
+ _klass = None
+ _valid_defaults = SequenceTypes
+ _types = None
+ _validators = None
+ _length = None
+
+ def __init__(self, default_value=None, allow_none=True, types=None,
+ **metadata):
+ """Create a container trait type from a list, set, or tuple.
-class List(Instance):
- """An instance of a Python list."""
-
- def __init__(self, default_value=None, allow_none=True, **metadata):
- """Create a list trait type from a list, set, or tuple.
-
- The default value is created by doing ``list(default_value)``,
+ The default value is created by doing ``<self._klass>(default_value)``,
which creates a copy of the ``default_value``.
"""
+ istrait = lambda t: isinstance(t, type) and issubclass(t, TraitType)
+
if default_value is None:
- args = ((),)
- elif isinstance(default_value, SequenceTypes):
- args = (default_value,)
+ args = ()
+ elif isinstance(default_value, self._valid_defaults):
+ if types is None and len(default_value) and\
+ all([ istrait(t) for t in default_value ]):
+ types = default_value
+ args = ()
+ else:
+ args = (default_value,)
+ elif istrait(default_value):
+ if types is None:
+ types = default_value
+ args = ()
+ else:
+ args = (default_value,)
else:
- raise TypeError('default value of List was %s' % default_value)
-
- super(List,self).__init__(klass=list, args=args,
+ raise TypeError('default value of %s was %s' %(self.__class__.__name__, default_value))
+
+ if types:
+ self._build_validators(types)
+
+ if self._length and args == ():
+ # don't allow default to be an empty container if length is specified
+ args = None
+ super(Container,self).__init__(klass=self._klass, args=args,
allow_none=allow_none, **metadata)
+
+ def _build_validators(self, trait_types):
+ """build empty Traits for validating elements"""
+ if isinstance(trait_types, SequenceTypes):
+ self._validators = []
+ for t in trait_types:
+ v = t()
+ v.name = "element"
+ self._validators.append(v)
+ self._length = len(trait_types)
+ else:
+ # single trait type
+ v = trait_types()
+ v.name = "element"
+ self._validators = v
+ self._length = None
+
+ def validator_info(self):
+ """info for element types"""
+ if self._length:
+ return ' or '.join([ v.info() for v in self._validators ])
+ else:
+ return self._validators.info()
+
+ def info(self):
+ if isinstance(self.klass, basestring):
+ klass = self.klass
+ else:
+ klass = self.klass.__name__
+ result = class_of(klass)
+ if self._length:
+ result = result + ' of length %i'%self._length
+ if self._allow_none:
+ return result + ' or None'
+ return result
+
+ def element_error(self, obj, element):
+ e = "Elements of the '%s' trait of %s instance must be %s, but a value of %s was specified." \
+ % (self.name, class_of(obj), self.validator_info(), repr_type(element))
+ raise TraitError(e)
+
+ def validate(self, obj, value):
+ if value is None:
+ if self._allow_none:
+ return value
+ self.error(obj, value)
+
+ if not isinstance(value, self._klass):
+ self.error(obj, value)
+
+ if self._length is not None and len(value) != self._length:
+ self.error(obj, value)
+
+ if self._validators:
+ self.validate_elements(obj, value)
+
+ return value
+
+ def get_validator(self, key):
+ if not isinstance(self._validators, list):
+ return self._validators
+ else:
+ return self._validators[key]
+
+ def validate_elements(self, obj, value):
+ validated = []
+ for i,v in enumerate(value):
+ validator = self.get_validator(i)
+ try:
+ v = validator.validate(obj, v)
+ except TraitError:
+ self.element_error(obj, v)
+ else:
+ validated.append(v)
+ return self._klass(validated)
+
-class Set(Instance):
- """An instance of a Python set."""
+class List(Container):
+ """An instance of a Python list."""
+ _klass = list
- def __init__(self, default_value=None, allow_none=True, **metadata):
- """Create a set trait type from a set, list, or tuple.
+class Tuple(Container):
+ """An instance of a Python tuple."""
+ _klass = tuple
- The default value is created by doing ``set(default_value)``,
- which creates a copy of the ``default_value``.
- """
- if default_value is None:
- args = ((),)
- elif isinstance(default_value, SequenceTypes):
- args = (default_value,)
+class Set(Container):
+ """An instance of a Python set."""
+ _klass = set
+
+ def _build_validators(self, trait_types):
+ """build empty Traits for validating elements"""
+ if isinstance(trait_types, SequenceTypes):
+ raise TraitError("Only one type may be enforced on Sets")
else:
- raise TypeError('default value of Set was %s' % default_value)
-
- super(Set,self).__init__(klass=set, args=args,
- allow_none=allow_none, **metadata)
-
+ # single trait type
+ v = trait_types()
+ v.name = "element"
+ self._validators = v
+ self._length = None
class Dict(Instance):
"""An instance of a Python dict."""
@@ -1075,7 +1189,6 @@ def __init__(self, default_value=None, allow_none=True, **metadata):
super(Dict,self).__init__(klass=dict, args=args,
allow_none=allow_none, **metadata)
-
class TCPAddress(TraitType):
"""A trait for an (ip, port) tuple.

0 comments on commit 4c91ed7

Please sign in to comment.
Something went wrong with that request. Please try again.