diff --git a/traitlets/tests/test_traitlets.py b/traitlets/tests/test_traitlets.py index 46e75745..15e1b3e1 100644 --- a/traitlets/tests/test_traitlets.py +++ b/traitlets/tests/test_traitlets.py @@ -22,7 +22,7 @@ Union, All, Undefined, Type, This, Instance, TCPAddress, List, Tuple, ObjectName, DottedObjectName, CRegExp, link, directional_link, ForwardDeclaredType, ForwardDeclaredInstance, validate, observe, default, - observe_compat, BaseDescriptor, HasDescriptors, + observe_compat, BaseDescriptor, HasDescriptors, ThisType ) import six @@ -1050,6 +1050,61 @@ class Tree(HasTraits): with self.assertRaises(TraitError): tree.leaves = [1, 2] +class TestThisType(TestCase): + + def test_this_class(self): + class Foo(HasTraits): + this_type = ThisType() + + f = Foo() + self.assertEqual(f.this_type, Foo) + f.this_type = Foo + self.assertEqual(f.this_type, Foo) + self.assertRaises(TraitError, setattr, f, 'this_type', 10) + + def test_this_inst(self): + class Foo(HasTraits): + this_type = ThisType() + + f = Foo() + f.this_type = Foo + self.assertTrue(issubclass(f.this_type, Foo)) + + def test_subclass(self): + class Foo(HasTraits): + t = ThisType() + class Bar(Foo): + pass + f = Foo() + b = Bar() + f.t = Bar + b.t = Foo + self.assertEqual(f.t, Bar) + self.assertEqual(b.t, Foo) + + def test_subclass_override(self): + class Foo(HasTraits): + t = ThisType() + class Bar(Foo): + t = ThisType() + f = Foo() + b = Bar() + f.t = Bar + self.assertEqual(f.t, Bar) + self.assertRaises(TraitError, setattr, b, 't', Foo) + + def test_this_in_container(self): + + class A(HasTraits): + types = Dict(ThisType()) + class B(A): pass + class C(B): pass + a = A( + types={'b': B, 'c': C} + ) + with self.assertRaises(TraitError): + a.types = {'b': 1, 'c': 2} + class TraitTestBase(TestCase): """A best testing class for basic trait types.""" diff --git a/traitlets/traitlets.py b/traitlets/traitlets.py index e4a250f2..9ee8adb2 100644 --- a/traitlets/traitlets.py +++ b/traitlets/traitlets.py @@ -1727,27 +1727,44 @@ class ForwardDeclaredInstance(ForwardDeclaredMixin, Instance): pass -class This(ClassBasedTraitType): - """A trait for instances of the class containing this trait. +class ThisClassMixin(object): - Because how how and when class bodies are executed, the ``This`` - trait can only have a default value of None. This, and because we - always validate default values, ``allow_none`` is *always* true. + # A temporary value until class_init is called + klass = type('UndefinedClass', (object,), {}) + + def class_init(self, cls, name): + super(ThisClassMixin, self).class_init(cls, name) + self.klass = self.this_class + + +class This(ThisClassMixin, Instance): + """A trait for instances of the class owning this trait. + + Because of how and when class bodies are executed, the ``This`` + trait holds a temporary class type until the owner class has been + setup by :class:`MetaHasDescriptors`. At which point, the final + instance type is assigned by :meth:`class_init`. """ info_text = 'an instance of the same type as the receiver or None' + allow_none = True - def __init__(self, **metadata): - super(This, self).__init__(None, **metadata) - def validate(self, obj, value): - # What if value is a superclass of obj.__class__? This is - # complicated if it was the superclass that defined the This - # trait. - if isinstance(value, self.this_class) or (value is None): - return value - else: - self.error(obj, value) +class ThisType(ThisClassMixin, Type): + """A trait for subclasses of the class owning this trait. + + Because of how and when class bodies are executed, the ``This`` + trait holds a temporary class type until the owner class has been + setup by :class:`MetaHasDescriptors`. At which point, the final + instance type is assigned by :meth:`class_init`. + """ + + def __init__(self, default_value=Undefined, **metadata): + super(ThisType, self).__init__(default_value, self.klass, **metadata) + + def class_init(self, cls, name): + super(ThisType, self).class_init(cls, name) + self.default_value = self.klass class Union(TraitType):