diff --git a/chainer/link.py b/chainer/link.py index 57b3803bb210..bd25006b1022 100644 --- a/chainer/link.py +++ b/chainer/link.py @@ -820,6 +820,13 @@ def __init__(self, *links): for link in links: self.add_link(link) + def __setattr__(self, name, value): + if self.within_init_scope and isinstance(value, Link): + raise TypeError( + 'cannot register a new link' + ' within a "with chainlist.init_scope():" block.') + super(ChainList, self).__setattr__(name, value) + def __getitem__(self, index): """Returns the child at given index. diff --git a/tests/chainer_tests/test_link.py b/tests/chainer_tests/test_link.py index 990f6ac9d121..5baf7414cfbd 100644 --- a/tests/chainer_tests/test_link.py +++ b/tests/chainer_tests/test_link.py @@ -746,6 +746,18 @@ def test_append(self): self.assertIs(self.c2[1], self.l3) self.assertEqual(self.l3.name, '1') + def test_assign_param_in_init_scope(self): + p = chainer.Parameter() + with self.c1.init_scope(): + self.c1.p = p + self.assertIn(p, self.c1.params()) + + def test_assign_link_in_init_scope(self): + l = chainer.Link() + with self.c1.init_scope(): + with self.assertRaises(TypeError): + self.c1.l = l + def test_iter(self): links = list(self.c2) self.assertEqual(2, len(links))