Skip to content

Commit

Permalink
Merge pull request #10 from SylvainCorlay/DictTraitValidation
Browse files Browse the repository at this point in the history
Adding key-specific dict validation
  • Loading branch information
minrk committed May 27, 2015
2 parents bf444ba + 80bd790 commit 5369bd4
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 11 deletions.
10 changes: 6 additions & 4 deletions traitlets/tests/test_traitlets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1093,15 +1093,17 @@ def test_dict_assignment():

class ValidatedDictTrait(HasTraits):

value = Dict(Unicode())
value = Dict(trait=Unicode(),
traits={'foo': Int()},
default_value={'foo': 1})

class TestInstanceDict(TraitTestBase):

obj = ValidatedDictTrait()

_default_value = {}
_good_values = [{'0': 'foo'}, {'1': 'bar'}]
_bad_values = [{'0': 0}, {'1': 1}]
_default_value = {'foo': 1}
_good_values = [{'0': 'foo', 'foo': 1}, {'1': 'bar', 'foo': 2}]
_bad_values = [{'0': 0, 'foo': 1}, {'1': 'bar', 'foo': 'bar'}]


def test_dict_default_value():
Expand Down
41 changes: 34 additions & 7 deletions traitlets/traitlets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1761,25 +1761,33 @@ class Dict(Instance):
"""An instance of a Python dict."""
_trait = None

def __init__(self, trait=None, default_value=NoDefaultSpecified, **metadata):
def __init__(self, trait=None, traits=None, default_value=NoDefaultSpecified,
**metadata):
"""Create a dict trait type from a dict.
The default value is created by doing ``dict(default_value)``,
which creates a copy of the ``default_value``.
trait : TraitType [ optional ]
the type for restricting the contents of the Container. If unspecified,
types are not checked.
The type for restricting the contents of the Container. If
unspecified, types are not checked.
traits : Dictionary of trait types [optional]
The type for restricting the content of the Dictionary for certain
keys.
default_value : SequenceType [ optional ]
The default value for the Dict. Must be dict, tuple, or None, and
will be cast to a dict if not None. If `trait` is specified, the
`default_value` must conform to the constraints it specifies.
"""
# Handling positional arguments
if default_value is NoDefaultSpecified and trait is not None:
if not is_trait(trait):
default_value = trait
trait = None

# Handling default value
if default_value is NoDefaultSpecified:
default_value = {}
if default_value is None:
Expand All @@ -1791,13 +1799,19 @@ def __init__(self, trait=None, default_value=NoDefaultSpecified, **metadata):
else:
raise TypeError('default value of Dict was %s' % default_value)

# Case where a type of TraitType is provided rather than an instance
if is_trait(trait):
self._trait = trait() if isinstance(trait, type) else trait
self._trait.name = 'element'
elif trait is not None:
raise TypeError("`trait` must be a Trait or None, got %s"%repr_type(trait))
raise TypeError("`trait` must be a Trait or None, got %s" % repr_type(trait))

super(Dict,self).__init__(klass=dict, args=args, **metadata)
self._traits = traits
if traits is not None:
for t in traits.values():
t.name = 'element'

super(Dict, self).__init__(klass=dict, args=args, **metadata)

def element_error(self, obj, element, validator):
e = "Element of the '%s' trait of %s instance must be %s, but a value of %s was specified." \
Expand All @@ -1812,13 +1826,22 @@ def validate(self, obj, value):
return value

def validate_elements(self, obj, value):
if self._trait is None or isinstance(self._trait, Any):
if self._traits is not None:
for key in self._traits:
if key not in value:
raise TraitError("Missing required '%s' key for the '%s' trait of %s instance"
% (key, self.name, class_of(obj)))
if self._traits is None and (self._trait is None or
isinstance(self._trait, Any)):
return value
validated = {}
for key in value:
v = value[key]
try:
v = self._trait._validate(obj, v)
if self._traits is not None and key in self._traits:
v = self._traits[key]._validate(obj, v)
else:
v = self._trait._validate(obj, v)
except TraitError:
self.element_error(obj, v, self._trait)
else:
Expand All @@ -1829,6 +1852,10 @@ def instance_init(self):
if isinstance(self._trait, TraitType):
self._trait.this_class = self.this_class
self._trait.instance_init()
if self._traits is not None:
for trait in self._traits.values():
trait.this_class = self.this_class
trait.instance_init()
super(Dict, self).instance_init()


Expand Down

0 comments on commit 5369bd4

Please sign in to comment.