Skip to content

Commit

Permalink
Fix cloning issue with container traits (#1624)
Browse files Browse the repository at this point in the history
This PR makes a shallow fix for problems with `clone_traits` applied to `List`, `Dict` and `Set` traits. It doesn't try to touch deeper issues of disconnection of `TraitsListObject`, `TraitDictObject` and friends from their owning `HasTraits` object.

Closes #1622. The cause of that issue is that we were [using a function](https://github.com/enthought/traits/blob/c8bd6e5f332d44b512b79f0bee3cb814b9125352/traits/trait_list_object.py#L815) `lambda: None` for `object` where a `HasTraits` object was expected. Inside `TraitListObject` we then [take a weakref](https://github.com/enthought/traits/blob/c8bd6e5f332d44b512b79f0bee3cb814b9125352/traits/trait_list_object.py#L572) to that function. In most cases, the `lambda` function has no other references to it, so it's garbage collected immediately and when the weakref is dereferenced, it returns `None`. But in the deepcopy case the weakref target is kept alive for long enough that we try to use the actual `lambda: None` function as a `HasTraits` object. The solution is to allow and special-case an object of `None` in the `TraitListObject` constructor.

Co-authored-by: Steve Allen <sallen@enthought.com>
  • Loading branch information
mdickinson and sallenEnth committed Mar 31, 2022
1 parent c8bd6e5 commit a307f57
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 18 deletions.
19 changes: 19 additions & 0 deletions traits/tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
DelegatesTo,
Dict,
Either,
Enum,
Instance,
Int,
List,
Expand Down Expand Up @@ -309,6 +310,24 @@ class A(HasTraits):
with self.assertRaises(ZeroDivisionError):
a.bar = "foo"

def test_clone_list_of_enum_trait(self):
# Regression test for enthought/traits#1622.

class Order(HasTraits):
menu = List(Str)
selection = List(Enum(values="menu"))

order = Order(menu=["fish"], selection=["fish"])
clone = order.clone_traits()

self.assertEqual(clone.selection, ["fish"])

order.selection.append('fish')
self.assertEqual(clone.selection, ['fish'])

with self.assertRaises(TraitError):
clone.selection.append("bouillabaisse")


class NestedContainerClass(HasTraits):
# Used in regression test for changes to nested containers
Expand Down
12 changes: 12 additions & 0 deletions traits/tests/test_trait_dict_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,3 +455,15 @@ class DifferentName(TraitDictEvent):
differnt_name_subclass = DifferentName()
self.assertEqual(desired_repr, str(differnt_name_subclass))
self.assertEqual(desired_repr, repr(differnt_name_subclass))

def test_disconnected_dict(self):
# Objects that are disconnected from their HasTraits "owner" can arise
# as a result of clone_traits operations, or of serialization and
# deserialization.
disconnected = TraitDictObject(
trait=Dict(Str, Str),
object=None,
name="foo",
value={},
)
self.assertEqual(disconnected.object(), None)
12 changes: 12 additions & 0 deletions traits/tests/test_trait_list_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -1457,3 +1457,15 @@ def test_dead_object_reference(self):
self.assertEqual(list_object, [1, 2, 3, 4, 5])
with self.assertRaises(TraitError):
list_object.append(4)

def test_disconnected_list(self):
# Objects that are disconnected from their HasTraits "owner" can arise
# as a result of clone_traits operations, or of serialization and
# deserialization.
disconnected = TraitListObject(
trait=List(Int),
object=None,
name="foo",
value=[1, 2, 3],
)
self.assertEqual(disconnected.object(), None)
14 changes: 13 additions & 1 deletion traits/tests/test_trait_set_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from traits.api import HasTraits, Set, Str
from traits.trait_base import _validate_everything
from traits.trait_errors import TraitError
from traits.trait_set_object import TraitSet, TraitSetEvent
from traits.trait_set_object import TraitSet, TraitSetEvent, TraitSetObject
from traits.trait_types import _validate_int


Expand Down Expand Up @@ -517,6 +517,18 @@ class Foo(HasTraits):
# then
notifier.assert_not_called()

def test_disconnected_set(self):
# Objects that are disconnected from their HasTraits "owner" can arise
# as a result of clone_traits operations, or of serialization and
# deserialization.
disconnected = TraitSetObject(
trait=Set(Str),
object=None,
name="foo",
value=set(),
)
self.assertEqual(disconnected.object(), None)


class TestTraitSetEvent(unittest.TestCase):

Expand Down
15 changes: 8 additions & 7 deletions traits/trait_dict_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,8 +414,9 @@ class TraitDictObject(TraitDict):
trait : CTrait instance
The CTrait instance associated with the attribute that this dict
has been set to.
object : HasTraits instance
The HasTraits instance that the dict has been set as an attribute for.
object : HasTraits
The object this dict belongs to. Can also be None in cases where the
dict has been disconnected from its HasTraits parent.
name : str
The name of the attribute on the object.
value : dict
Expand All @@ -426,9 +427,9 @@ class TraitDictObject(TraitDict):
trait : CTrait instance
The CTrait instance associated with the attribute that this dict
has been set to.
object : weak reference to a HasTraits instance
A weak reference to a HasTraits instance that the dict has been set
as an attribute for.
object : callable
A callable that when called with no arguments returns the HasTraits
object that this dict belongs to, or None if there is no such object.
name : str
The name of the attribute on the object.
name_items : str
Expand All @@ -438,7 +439,7 @@ class TraitDictObject(TraitDict):

def __init__(self, trait, object, name, value):
self.trait = trait
self.object = ref(object)
self.object = (lambda: None) if object is None else ref(object)
self.name = name
self.name_items = None
if trait.has_items:
Expand Down Expand Up @@ -585,7 +586,7 @@ def __deepcopy__(self, memo):
"""
result = TraitDictObject(
self.trait,
lambda: None,
None,
self.name,
dict(copy.deepcopy(x, memo) for x in self.items()),
)
Expand Down
12 changes: 7 additions & 5 deletions traits/trait_list_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,8 @@ class TraitListObject(TraitList):
trait : CTrait
The trait that the list has been assigned to.
object : HasTraits
The object the list belongs to.
The object this list belongs to. Can also be None in cases where the
list has been disconnected from its HasTraits parent.
name : str
The name of the trait on the object.
value : iterable
Expand All @@ -558,8 +559,9 @@ class TraitListObject(TraitList):
----------
trait : CTrait
The trait that the list has been assigned to.
object : HasTraits
The object the list belongs to.
object : callable
A callable that when called with no arguments returns the HasTraits
object that this list belongs to, or None if there is no such object.
name : str
The name of the trait on the object.
value : iterable
Expand All @@ -569,7 +571,7 @@ class TraitListObject(TraitList):
def __init__(self, trait, object, name, value):

self.trait = trait
self.object = ref(object)
self.object = (lambda: None) if object is None else ref(object)
self.name = name
self.name_items = None
if trait.has_items:
Expand Down Expand Up @@ -812,7 +814,7 @@ def __deepcopy__(self, memo):
"""
return TraitListObject(
self.trait,
lambda: None,
None,
self.name,
[copy.deepcopy(x, memo) for x in self],
)
Expand Down
12 changes: 7 additions & 5 deletions traits/trait_set_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,8 @@ class TraitSetObject(TraitSet):
trait : CTrait
The trait that the set has been assigned to.
object : HasTraits
The object the set belongs to.
The object this set belongs to. Can also be None in cases where the
set has been disconnected from its HasTraits parent.
name : str
The name of the trait on the object.
value : iterable
Expand All @@ -461,8 +462,9 @@ class TraitSetObject(TraitSet):
----------
trait : CTrait
The trait that the set has been assigned to.
object : HasTraits
The object the set belongs to.
object : callable
A callable that when called with no arguments returns the HasTraits
object that this set belongs to, or None if there is no such object.
name : str
The name of the trait on the object.
value : iterable
Expand All @@ -472,7 +474,7 @@ class TraitSetObject(TraitSet):
def __init__(self, trait, object, name, value):

self.trait = trait
self.object = ref(object)
self.object = (lambda: None) if object is None else ref(object)
self.name = name
self.name_items = None
if trait.has_items:
Expand Down Expand Up @@ -560,7 +562,7 @@ def __deepcopy__(self, memo):

result = TraitSetObject(
self.trait,
lambda: None,
None,
self.name,
{copy.deepcopy(x, memo) for x in self},
)
Expand Down

0 comments on commit a307f57

Please sign in to comment.