Skip to content

Commit

Permalink
Add copy and deepcopy function to node class. (#800)
Browse files Browse the repository at this point in the history
* Add copy and deepcopy function to node class. This function will only be used by nodes which can be converted to tree.
* Add test
  • Loading branch information
drodarie committed Jan 18, 2024
1 parent 27ab0a9 commit 978fded
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 0 deletions.
2 changes: 2 additions & 0 deletions bsb/config/_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
compile_isc,
compile_new,
compile_postnew,
make_copyable,
make_dictable,
make_get_node_name,
make_tree,
Expand Down Expand Up @@ -72,6 +73,7 @@ def node(node_cls, root=False, dynamic=False, pluggable=False):
make_get_node_name(node_cls, root=root)
make_tree(node_cls)
make_dictable(node_cls)
make_copyable(node_cls)

return node_cls

Expand Down
8 changes: 8 additions & 0 deletions bsb/config/_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,14 @@ def get_tree(instance):
node_cls.__tree__ = get_tree


def make_copyable(node_cls):
def loc_copy(instance, memo=None):
return type(instance)(instance.__tree__())

node_cls.__copy__ = loc_copy
node_cls.__deepcopy__ = loc_copy


def walk_node_attributes(node):
"""
Walk over all of the child configuration nodes and attributes of ``node``.
Expand Down
34 changes: 34 additions & 0 deletions tests/test_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1286,6 +1286,40 @@ class Test:
self.assertEqual({"statement": "[1, 2, 3]"}, tree["a"])


class TestCopy(unittest.TestCase):
def test_copy(self):
"""
Check copy and deepcopy functions for the nodes.
"""

@config.node
class SubClass:
c = config.attr(
required=False,
default=lambda: np.array([0, 0, 0], dtype=int),
call_default=True,
type=types.ndarray(),
)

@config.root
class MainClass:
a = config.attr(type=SubClass)
b = config.attr(default=5.0)

tab = np.array([1, 2, 3], dtype=int)
instance = MainClass({"a": {"c": tab}, "b": 3.0})
copied = instance.__copy__()
self.assertTrue(id(instance.a) != id(copied.a))
# check that the c arrays elements are equals
self.assertTrue(np.all(instance.a.c == copied.a.c))
self.assertEqual(instance.b, copied.b)
copied = instance.__deepcopy__()
self.assertTrue(id(instance.a) != id(copied.a))
# check that the c arrays elements are equals
self.assertTrue(np.all(instance.a.c == copied.a.c))
self.assertEqual(instance.b, copied.b)


class TestDictScripting(unittest.TestCase):
def test_add(self):
netw = Scaffold()
Expand Down

0 comments on commit 978fded

Please sign in to comment.