diff --git a/docs/source/traits_api_reference/has_traits.rst b/docs/source/traits_api_reference/has_traits.rst index 1defb67145..8d83a7c120 100644 --- a/docs/source/traits_api_reference/has_traits.rst +++ b/docs/source/traits_api_reference/has_traits.rst @@ -37,6 +37,8 @@ Classes .. autoclass:: HasStrictTraits +.. autoclass:: HasRequiredTraits + .. autoclass:: HasPrivateTraits .. autoclass:: SingletonHasTraits diff --git a/docs/source/traits_user_manual/advanced.rst b/docs/source/traits_user_manual/advanced.rst index 3f92fedda3..06e4414469 100644 --- a/docs/source/traits_user_manual/advanced.rst +++ b/docs/source/traits_user_manual/advanced.rst @@ -301,6 +301,38 @@ exception, as does attempting to set an attribute that is not one of the three defined attributes. In essence, TreeNode behaves like a type-checked data structure. +.. index:: HasRequiredTraits class + +.. _hasrequiredtraits: + +HasRequiredTraits +''''''''''''''''' + +This class builds on the functionality of HasStrictTraits and ensures +that any object attribute with `required=True` in its metadata must be passed +as an argument on object initialization. + +An example of a class with required traits:: + + class RequiredTest(HasRequiredTraits): + required_trait = Any(required=True) + non_required_trait = Any() + +All required traits have to be provided as arguments on creating a new +instance:: + + >>> new_instance = RequiredTest(required_trait=13.0) + +Non-required traits can also still be provided as usual:: + + >>> new_instance = RequiredTest(required_trait=13.0, non_required_trait=14.0) + +However, omitting a required trait will raise a TraitError:: + + >>> new_instance = RequiredTest(non_required_trait=14.0) + traits.trait_errors.TraitError: The following required traits were not + provided: required_trait. + .. index:: HasPrivateTraits class .. _hasprivatetraits: diff --git a/traits/api.py b/traits/api.py index 20881acab2..5f2cf06bfc 100644 --- a/traits/api.py +++ b/traits/api.py @@ -65,10 +65,10 @@ from .trait_types import UUID, ValidatedTuple from .has_traits import (HasTraits, HasStrictTraits, HasPrivateTraits, - Interface, SingletonHasTraits, SingletonHasStrictTraits, - SingletonHasPrivateTraits, MetaHasTraits, Vetoable, VetoableEvent, - implements, traits_super, on_trait_change, cached_property, - property_depends_on, provides, isinterface) + HasRequiredTraits, Interface, SingletonHasTraits, + SingletonHasStrictTraits, SingletonHasPrivateTraits, MetaHasTraits, + Vetoable, VetoableEvent, implements, traits_super, on_trait_change, + cached_property, property_depends_on, provides, isinterface) try: from .has_traits import ABCHasTraits, ABCHasStrictTraits, ABCMetaHasTraits diff --git a/traits/has_traits.py b/traits/has_traits.py index 54420156ba..c5fa255873 100644 --- a/traits/has_traits.py +++ b/traits/has_traits.py @@ -3455,6 +3455,57 @@ class HasStrictTraits ( HasTraits ): """ _ = Disallow # Disallow access to any traits not explicitly defined +#------------------------------------------------------------------------------- +# 'HasRequiredTraits' class: +#------------------------------------------------------------------------------- + +class HasRequiredTraits(HasStrictTraits): + """ This class builds on the functionality of HasStrictTraits and ensures + that any object attribute with `required=True` in its metadata must be + passed as an argument on object initialization. + + This can be useful in cases where an object has traits which are required + for it to function correctly. + + Raises + ------ + TraitError + If a required trait is not passed as an argument. + + Usage + ----- + A class with required traits: + + >>> class RequiredTest(HasRequiredTraits): + ... required_trait = Any(required=True) + ... non_required_trait = Any() + + Creating an instance of a HasRequiredTraits subclass: + + >>> test_instance = RequiredTest(required_trait=13, non_required_trait=11) + >>> test_instance2 = RequiredTest(required_trait=13) + + Forgetting to specify a required trait: + + >>> test_instance = RequiredTest(non_required_trait=11) + traits.trait_errors.TraitError: The following required traits were not + provided: required_trait. + """ + + def __init__(self, **traits): + + missing_required_traits = [ + name for name in self.trait_names(required=True) + if name not in traits + ] + if missing_required_traits: + raise TraitError( + "The following required traits were not provided: " + "{}.".format(', '.join(sorted(missing_required_traits))) + ) + + super(HasRequiredTraits, self).__init__(**traits) + #------------------------------------------------------------------------------- # 'HasPrivateTraits' class: #------------------------------------------------------------------------------- diff --git a/traits/tests/test_has_required_traits.py b/traits/tests/test_has_required_traits.py new file mode 100644 index 0000000000..30ec33de68 --- /dev/null +++ b/traits/tests/test_has_required_traits.py @@ -0,0 +1,30 @@ +import unittest +from traits.api import Int, Float, String, HasRequiredTraits, TraitError + +class TestHasRequiredTraits(unittest.TestCase): + + def test_trait_value_assignment(self): + test_instance = RequiredTest( + i_trait=4, f_trait=2.2, s_trait="test") + self.assertEqual(test_instance.i_trait, 4) + self.assertEqual(test_instance.f_trait, 2.2) + self.assertEqual(test_instance.s_trait, "test") + self.assertEqual(test_instance.non_req_trait, 4.4) + self.assertEqual(test_instance.normal_trait, 42.0) + + + def test_missing_required_trait(self): + with self.assertRaises(TraitError) as exc: + test_instance = RequiredTest(i_trait=3) + self.assertEqual( + exc.exception.args[0], "The following required traits were not " + "provided: f_trait, s_trait." + ) + + +class RequiredTest(HasRequiredTraits): + i_trait = Int(required=True) + f_trait = Float(required=True) + s_trait = String(required=True) + non_req_trait = Float(4.4, required=False) + normal_trait = Float(42.0)