Skip to content

Commit

Permalink
Merge pull request #31 from takluyver/instance-dynamic-defaults
Browse files Browse the repository at this point in the history
Don't create default values for Instance until required
  • Loading branch information
minrk committed Jun 15, 2015
2 parents 128cf8e + dfbd585 commit 1e89907
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 104 deletions.
22 changes: 13 additions & 9 deletions traitlets/tests/test_traitlets.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,24 +114,20 @@ def _x_default(self):

a = A()
self.assertEqual(a._trait_values, {})
self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
self.assertEqual(a.x, 11)
self.assertEqual(a._trait_values, {'x': 11})
b = B()
self.assertEqual(b._trait_values, {'x': 20})
self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
self.assertEqual(b.x, 20)
self.assertEqual(b._trait_values, {'x': 20})
c = C()
self.assertEqual(c._trait_values, {})
self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
self.assertEqual(c.x, 21)
self.assertEqual(c._trait_values, {'x': 21})
# Ensure that the base class remains unmolested when the _default
# initializer gets overridden in a subclass.
a = A()
c = C()
self.assertEqual(a._trait_values, {})
self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
self.assertEqual(a.x, 11)
self.assertEqual(a._trait_values, {'x': 11})

Expand Down Expand Up @@ -448,7 +444,7 @@ class A(HasTraits):
klass = Type(allow_none=True)

a = A()
self.assertEqual(a.klass, None)
self.assertEqual(a.klass, object)

a.klass = B
self.assertEqual(a.klass, B)
Expand Down Expand Up @@ -606,7 +602,9 @@ class Foo(object): pass
class A(HasTraits):
inst = Instance(Foo)

self.assertRaises(TraitError, A)
a = A()
with self.assertRaises(TraitError):
a.inst

def test_instance(self):
class Foo(object): pass
Expand Down Expand Up @@ -1110,8 +1108,14 @@ def test_dict_default_value():
"""Check that the `{}` default value of the Dict traitlet constructor is
actually copied."""

d1, d2 = Dict(), Dict()
nt.assert_false(d1.get_default_value() is d2.get_default_value())
class Foo(HasTraits):
d1 = Dict()
d2 = Dict()

foo = Foo()
nt.assert_equal(foo.d1, {})
nt.assert_equal(foo.d2, {})
nt.assert_is_not(foo.d1, foo.d2)


class TestValidationHook(TestCase):
Expand Down
155 changes: 60 additions & 95 deletions traitlets/traitlets.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ class link(object):
Examples
--------
>>> c = link((src, 'value'), (tgt, 'value'),
>>> c = link((src, 'value'), (tgt, 'value'))
>>> src.value = 5 # updates other objects as well
"""
updating = False
Expand Down Expand Up @@ -371,41 +371,50 @@ def init(self):
pass

def get_default_value(self):
"""Create a new instance of the default value."""
"""Retrieve the static default value for this trait"""
return self.default_value

def init_default_value(self, obj):
"""Instantiate the default value for the trait type.
def validate_default_value(self, obj):
"""Retrieve and validate the static default value"""
v = self.get_default_value()
return self._validate(obj, v)

This method is called when accessing the trait value for the first
time in :meth:`HasTraits.__get__`.
def init_default_value(self, obj):
"""DEPRECATED: Set the static default value for the trait type.
"""
value = self.get_default_value()
value = self._validate(obj, value)
warn("init_default_value is deprecated, and may be removed in the future",
stacklevel=2)
value = self.validate_default_value()
obj._trait_values[self.name] = value
return value

def _setup_dynamic_initializer(self, obj):
# Check for a deferred initializer defined in the same class as the
# trait declaration or above.
mro = type(obj).mro()
meth_name = '_%s_default' % self.name
for cls in mro[:mro.index(self.this_class)+1]:
if meth_name in cls.__dict__:
break
else:
return False
# Complete the dynamic initialization.
obj._trait_dyn_inits[self.name] = meth_name
return True

def _set_default_value_at_instance_init(self, obj):
# As above, but if no default was specified, don't try to set it.
# If the trait is accessed before it is given a value, init_default_value
# will be called at that point.
if (not self._setup_dynamic_initializer(obj)) \
def _dynamic_default_callable(self, obj):
"""Retrieve a callable to calculate the default for this traitlet.
This looks for:
- obj._{name}_default() on the class with the traitlet, or a subclass
that obj belongs to.
- trait.make_dynamic_default, which is defined by Instance
If neither exist, it returns None
"""
# Traitlets without a name are not on the instance, e.g. in List or Union
if self.name:
mro = type(obj).mro()
meth_name = '_%s_default' % self.name
for cls in mro[:mro.index(self.this_class)+1]:
if meth_name in cls.__dict__:
return getattr(obj, meth_name)

return getattr(self, 'make_dynamic_default', None)

def instance_init(self, obj):
# If no dynamic initialiser is present, and the trait implementation or
# use provides a static default, transfer that to obj._trait_values.
if (self._dynamic_default_callable(obj) is None) \
and (self.default_value is not Undefined):
self.init_default_value(obj)
self.validate_default_value(obj)

def __get__(self, obj, cls=None):
"""Get the value of the trait by self.name for the instance.
Expand All @@ -422,15 +431,13 @@ def __get__(self, obj, cls=None):
value = obj._trait_values[self.name]
except KeyError:
# Check for a dynamic initializer.
if self.name in obj._trait_dyn_inits:
method = getattr(obj, obj._trait_dyn_inits[self.name])
value = method()
# FIXME: Do we really validate here?
value = self._validate(obj, value)
obj._trait_values[self.name] = value
return value
dynamic_default = self._dynamic_default_callable(obj)
if dynamic_default is not None:
value = self._validate(obj, dynamic_default())
else:
return self.init_default_value(obj)
value = self.validate_default_value(obj)
obj._trait_values[self.name] = value
return value
except Exception:
# This should never be reached.
raise TraitError('Unexpected error in TraitType: '
Expand All @@ -443,7 +450,7 @@ def __set__(self, obj, value):
try:
old_value = obj._trait_values[self.name]
except KeyError:
old_value = Undefined
old_value = self.get_default_value()

obj._trait_values[self.name] = new_value
try:
Expand Down Expand Up @@ -554,7 +561,6 @@ def __new__(cls, *args, **kw):
inst = new_meth(cls, **kw)
inst._trait_values = {}
inst._trait_notifiers = {}
inst._trait_dyn_inits = {}
inst._cross_validation_lock = True
# Here we tell all the TraitType instances to set their default
# values on the instance.
Expand All @@ -569,8 +575,6 @@ def __new__(cls, *args, **kw):
else:
if isinstance(value, BaseDescriptor):
value.instance_init(inst)
if isinstance(value, TraitType) and key not in kw:
value._set_default_value_at_instance_init(inst)
inst._cross_validation_lock = False
return inst

Expand Down Expand Up @@ -901,7 +905,7 @@ def __init__ (self, default_value=None, klass=None, **metadata):
a particular class.
If only ``default_value`` is given, it is used for the ``klass`` as
well.
well. If neither are given, both default to ``object``.
Parameters
----------
Expand All @@ -915,13 +919,12 @@ def __init__ (self, default_value=None, klass=None, **metadata):
may be specified in a string like: 'foo.bar.MyClass'.
The string is resolved into real class, when the parent
:class:`HasTraits` class is instantiated.
allow_none : bool [ default True ]
Indicates whether None is allowed as an assignable value. Even if
``False``, the default value may be ``None``.
allow_none : bool [ default False ]
Indicates whether None is allowed as an assignable value.
"""
if default_value is None:
if klass is None:
klass = object
default_value = klass = object
elif klass is None:
klass = default_value

Expand Down Expand Up @@ -969,20 +972,6 @@ def _resolve_classes(self):
if isinstance(self.default_value, py3compat.string_types):
self.default_value = self._resolve_string(self.default_value)

def get_default_value(self):
return self.default_value


class DefaultValueGenerator(object):
"""A class for generating new default value instances."""

def __init__(self, *args, **kw):
self.args = args
self.kw = kw

def generate(self, klass):
return klass(*self.args, **self.kw)


class Instance(ClassBasedTraitType):
"""A trait whose value must be an instance of a specified class.
Expand Down Expand Up @@ -1030,25 +1019,15 @@ class or its subclasses. Our implementation is quite different
raise TraitError('The klass attribute must be a class'
' not: %r' % klass)

# self.klass is a class, so handle default_value
if args is None and kw is None:
default_value = None
else:
if args is None:
# kw is not None
args = ()
elif kw is None:
# args is not None
kw = {}

if not isinstance(kw, dict):
raise TraitError("The 'kw' argument must be a dict or None.")
if not isinstance(args, tuple):
raise TraitError("The 'args' argument must be a tuple or None.")
if (kw is not None) and not isinstance(kw, dict):
raise TraitError("The 'kw' argument must be a dict or None.")
if (args is not None) and not isinstance(args, tuple):
raise TraitError("The 'args' argument must be a tuple or None.")

default_value = DefaultValueGenerator(*args, **kw)
self.default_args = args
self.default_kwargs = kw

super(Instance, self).__init__(default_value, **metadata)
super(Instance, self).__init__(**metadata)

def validate(self, obj, value):
if isinstance(value, self.klass):
Expand All @@ -1075,18 +1054,11 @@ def _resolve_classes(self):
if isinstance(self.klass, py3compat.string_types):
self.klass = self._resolve_string(self.klass)

def get_default_value(self):
"""Instantiate a default value instance.
This is called when the containing HasTraits classes'
:meth:`__new__` method is called to ensure that a unique instance
is created for each HasTraits instance.
"""
dv = self.default_value
if isinstance(dv, DefaultValueGenerator):
return dv.generate(self.klass)
else:
return dv
def make_dynamic_default(self):
if (self.default_args is None) and (self.default_kwargs is None):
return None
return self.klass(*(self.default_args or ()),
**(self.default_kwargs or {}))


class ForwardDeclaredMixin(object):
Expand Down Expand Up @@ -1166,7 +1138,6 @@ def __init__(self, trait_types, **metadata):

def instance_init(self, obj):
for trait_type in self.trait_types:
trait_type.name = self.name
trait_type.this_class = self.this_class
trait_type.instance_init(obj)
super(Union, self).instance_init(obj)
Expand Down Expand Up @@ -1516,7 +1487,6 @@ def __init__(self, trait=None, default_value=None, **metadata):

if is_trait(trait):
self._trait = trait() if isinstance(trait, type) else trait
self._trait.name = 'element'
elif trait is not None:
raise TypeError("`trait` must be a Trait or None, got %s"%repr_type(trait))

Expand Down Expand Up @@ -1710,7 +1680,6 @@ def __init__(self, *traits, **metadata):
self._traits = []
for trait in traits:
t = trait() if isinstance(trait, type) else trait
t.name = 'element'
self._traits.append(t)

if self._traits and default_value is None:
Expand Down Expand Up @@ -1790,14 +1759,10 @@ def __init__(self, trait=None, traits=None, default_value=NoDefaultSpecified,
# Case where a type of TraitType is provided rather than an instance
if is_trait(trait):
self._trait = trait() if isinstance(trait, type) else trait
self._trait.name = 'element'
elif trait is not None:
raise TypeError("`trait` must be a Trait or None, got %s" % repr_type(trait))

self._traits = traits
if traits is not None:
for t in traits.values():
t.name = 'element'

super(Dict, self).__init__(klass=dict, args=args, **metadata)

Expand Down

0 comments on commit 1e89907

Please sign in to comment.