New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add count_params method to Link #3101
Conversation
chainer/link.py
Outdated
@@ -558,6 +558,28 @@ def serialize(self, serializer): | |||
for name in self._persistent: | |||
d[name] = serializer(name, d[name]) | |||
|
|||
@property | |||
def size(self): | |||
"""Count the size of parameters. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the "number" of parameter is more appropriate.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought "number of parameters" could mean the number of "Parameter" objects in the link. So, I used "size of parameters". How do you think about it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree with you. Then, how about "the total size of parameters" to make the meaning clearer?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, that's better. I'll modify it.
It seems that this property counts the number of parameters directly attached to the link of interest only and does not include parameters of its children. It depends on use cases which specification is suitable. Could you share me what use case do you have in mind? |
@delta2323 Thanks for the comment.
because some
so it works correctly and I think it's useful to know the "number of parameters" for ensuring the correctness of implementation or comparing the model's heaviness with others. |
@mitmul Sorry, I wrongly remembered the specification of |
Could you add unit tests that check |
I think the name On the other hand, a Do you come up with any? |
@delta2323 OK, I'll add tests for @niboshi I understand what you pointed out. How about |
Keras uses By the way, I would doubt the need of this feature. You can write I think it's better to hear other members' opinions as well. |
Thank you for your comment. Hmm, |
IMO, if it is a routine work to count the number of parameters in a link, it would not be strange that Chainer provides it as a utility function. |
It seems to me that |
This issue has been automatically marked as stale because it has not had recent activity. It will be closed after 30 days if no further activity occurs. Thank you for your contributions. |
I prefer |
I think it is a good idea to align the semantics of property names to other frameworks if possible. In that sense, |
This issue has been automatically marked as stale because it has not had recent activity. It will be closed after 30 days if no further activity occurs. Thank you for your contributions. |
I'll change the name to |
chainer/link.py
Outdated
def count_params(self): | ||
"""Count the total size of parameters. | ||
|
||
If the link containts uninitialized parameters, this raises warnings. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should emphasize that this method counts the number of scalars.
How about something like this?
(Just an example. Please fix as you like)
Counts the total number of parameters.
This method counts the total number of scalar values included in all the Parameter
s held by this link and its descendants.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's right. Good suggestion
chainer/link.py
Outdated
def count_params(self): | ||
"""Count the total size of parameters. | ||
|
||
If the link containts uninitialized parameters, this raises warnings. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this -> this method
warnings -> a warning
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. fixed.
chainer/link.py
Outdated
for name, param in self.namedparams(): | ||
if param.array is None: | ||
warnings.warn( | ||
'{} has not been initialized, so the resulting size will ' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
{}
-> Parameter '{}'
?
size -> count?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. fixed.
tests/chainer_tests/test_link.py
Outdated
@@ -789,6 +801,16 @@ def test_serialize(self): | |||
mocks['l1'].assert_called_with('x', self.l1.x.data) | |||
mocks['l2'].assert_called_with('x', self.l2.x.data) | |||
|
|||
@unittest.skipUnless( | |||
six.PY3, 'Python2.7 has a bug in catch_warnings, so this test is ' | |||
'skipped for Python2.7') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What kind of bug? How about using testing.assert_warns
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I forgot why I used this. I will try assert_warns
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, in Python 2.7, catch_warnings can't count the number of warnings correctly. So the self.assertEqual(len(w), 2)
fails.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, I see. Is this check needed? Warning is printed only once anyway.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can remove the warning count checking. I removed them.
I think it finished to fix them to follow reviews. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM except this comment.
chainer/link.py
Outdated
"""Counts the total number of parameters. | ||
|
||
This method counts the total number of scalar values included in all | ||
the :class:`~chainer.Parameters` held by this link and its descendants. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this should be:
:class:`~chainer.Parameter`\\ s
@delta2323 LGTM to me except the last comment. Could you merge if LGTM to you? |
@mitmul Please use |
Strictly speaking, this PR does not check |
tests/chainer_tests/test_link.py
Outdated
@@ -468,6 +470,12 @@ def test_update_enabled(self): | |||
self.link.enable_update() | |||
self.assertTrue(self.link.update_enabled) | |||
|
|||
def test_count_params(self): | |||
self.assertEqual(self.link.count_params(), 8) | |||
with warnings.catch_warnings(record=True) as w: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Check the number of warnings issued is correct (maybe 1) and the type of the warning is UserWarning
with w
. Otherwise, remove as w
as it is not used.
tests/chainer_tests/test_link.py
Outdated
@@ -789,6 +797,12 @@ def test_serialize(self): | |||
mocks['l1'].assert_called_with('x', self.l1.x.data) | |||
mocks['l2'].assert_called_with('x', self.l2.x.data) | |||
|
|||
def test_count_params(self): | |||
self.assertEqual(self.c1.count_params(), 8) | |||
with warnings.catch_warnings(record=True) as w: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto
tests/chainer_tests/test_link.py
Outdated
@@ -1102,6 +1116,12 @@ def test_serialize(self): | |||
mocks['0'].assert_called_with('y', l1.y.data) | |||
mocks['1'].assert_called_with('x', l2.x.data) | |||
|
|||
def test_count_params(self): | |||
self.assertEqual(self.c1.count_params(), 8) | |||
with warnings.catch_warnings(record=True) as w: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto
@delta2323 @niboshi |
IMO, you can just use |
Codecov Report
@@ Coverage Diff @@
## master #3101 +/- ##
=========================================
+ Coverage 89.12% 90.6% +1.48%
=========================================
Files 320 320
Lines 21211 21021 -190
=========================================
+ Hits 18905 19047 +142
+ Misses 2306 1974 -332
Continue to review full report at Codecov.
|
Hmm, but I succeeded to count the number of warnings. When two warnings raised, |
@delta2323 Well, I added the test cases for this:
|
I understand it's possible to count the warnings in Python 3, by using |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM except comments
tests/chainer_tests/test_link.py
Outdated
with warnings.catch_warnings(record=True) as w: | ||
warnings.simplefilter('always') | ||
self.link.count_params() | ||
assert len(w) == 0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can check the condition simply with assert not w
because w is a (empty) list
(cf. https://docs.python.org/3.6/library/warnings.html#warnings.catch_warnings)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But current code is also OK as it can be more intuitive.
tests/chainer_tests/test_link.py
Outdated
with warnings.catch_warnings(record=True) as w: | ||
warnings.simplefilter('always') | ||
self.c2.count_params() | ||
assert len(w) == 0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto
tests/chainer_tests/test_link.py
Outdated
with warnings.catch_warnings(record=True) as w: | ||
warnings.simplefilter('always') | ||
self.c2.count_params() | ||
assert len(w) == 0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto
@delta2323 Thanks for the comment. I fixed the points. |
Jenkins, test this please. |
@niboshi LGTM. Could you merge this PR if you approve it? |
LGTM! |
This PR adds
size
property toLink
class. It enables to expose the size of parameters in aLink
,Chain
,ChainList
and all objects inherited fromLink
. If aLink
contains uninitialized parameters, a warning message will show up to notify users that the size of them will not be counted in the resulting size.