Add count_params method to Link #3101

Merged
merged 16 commits into from Mar 23, 2018

Conversation

mitmul
Member

@mitmul mitmul commented Aug 5, 2017

This PR adds a size property to the Link class. It exposes the total size of the parameters in a Link, Chain, ChainList, or any other object that inherits from Link. If a Link contains uninitialized parameters, a warning notifies users that those parameters are not counted in the result.

@okuta okuta added the cat:feature Implementation that introduces new interfaces. label Aug 5, 2017
chainer/link.py Outdated
@@ -558,6 +558,28 @@ def serialize(self, serializer):
        for name in self._persistent:
            d[name] = serializer(name, d[name])

    @property
    def size(self):
        """Count the size of parameters.
Member

I think the "number" of parameters is more appropriate.

Member Author

I thought "number of parameters" could mean the number of Parameter objects in the link, so I used "size of parameters". What do you think?

Member

@delta2323 delta2323 Aug 7, 2017

I agree with you. Then, how about "the total size of parameters" to make the meaning clearer?

Member Author

Yes, that's better. I'll modify it.

@delta2323
Member

delta2323 commented Aug 6, 2017

It seems that this property counts only the parameters directly attached to the link of interest and does not include the parameters of its children. Which specification is suitable depends on the use case. Could you share what use case you have in mind?

@delta2323 delta2323 self-assigned this Aug 6, 2017
@mitmul
Member Author

mitmul commented Aug 7, 2017

@delta2323 Thanks for the comment. Link's namedparams() does return only the parameters that are directly attached, but Chain's and ChainList's namedparams() traverse all children, so if size is called on a Chain or ChainList, it returns the count of all parameters in the model. For example, when I called this size property on the MLP model defined in the official MNIST example, I got:

In [1]: from train_mnist import MLP

In [2]: model = MLP(100, 10)

In [3]: model.size
/Users/shunta/lib/chainer/chainer/link.py:578: UserWarning: /l1/W has not been initialized, so the resulting size will not include the number of parameters in it.
  'not include the number of parameters in it.'.format(name))
/Users/shunta/lib/chainer/chainer/link.py:578: UserWarning: /l2/W has not been initialized, so the resulting size will not include the number of parameters in it.
  'not include the number of parameters in it.'.format(name))
/Users/shunta/lib/chainer/chainer/link.py:578: UserWarning: /l3/W has not been initialized, so the resulting size will not include the number of parameters in it.
  'not include the number of parameters in it.'.format(name))
Out[3]: 210

because some Links have not been initialized (the above warnings are intended), but after the initial forward pass, I got:

In [6]: model(np.zeros((1, 10), dtype=np.float32))
Out[6]: variable([[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]])

In [7]: model.size
Out[7]: 12210

so it works correctly. I think it's useful to know the "number of parameters" when checking that an implementation is correct or when comparing a model's size with others.
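
The two totals above can be checked by hand. The following is a minimal sketch, assuming the MLP in the MNIST example consists of three fully connected layers (l1, l2, l3) whose biases are allocated at construction while each W is allocated on the first forward pass. The helper names are hypothetical, and the input width of 10 is taken from the np.zeros((1, 10)) call.

```python
def linear_param_counts(n_in, n_out):
    """Return (weight_count, bias_count) for one fully connected layer."""
    return n_in * n_out, n_out


def mlp_param_counts(n_in, n_units, n_out):
    """Sum weight and bias counts over the three assumed layers."""
    layers = [(n_in, n_units), (n_units, n_units), (n_units, n_out)]
    weights = sum(linear_param_counts(i, o)[0] for i, o in layers)
    biases = sum(linear_param_counts(i, o)[1] for i, o in layers)
    return weights, biases


weights, biases = mlp_param_counts(n_in=10, n_units=100, n_out=10)
print(biases)            # biases only: the count before the first forward pass
print(weights + biases)  # weights and biases: the count after the forward pass
```

Under these assumptions the biases alone sum to 100 + 100 + 10, and the weights add 10 * 100 + 100 * 100 + 100 * 10 on top of that.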

@delta2323
Member

@mitmul Sorry, I wrongly remembered the specification of namedparams. I agree with this specification as users typically want to know the number of parameters in a whole architecture.

@delta2323
Member

delta2323 commented Aug 7, 2017

Could you add unit tests that check that the property also works for Chain and ChainList? Also, could you add unit tests that check that the warning is not issued once the parameters have been initialized by the first forward propagation?

@niboshi
Member

niboshi commented Aug 7, 2017

I think the name size is too general for this feature. It originates from ndarray.size, which represents the size of a tensor. It is straightforward that the size of a tensor is its number of elements. Variable also has a size attribute, which I think is fine because a Variable is inherently a tensor.

On the other hand, a Link can be thought of as a collection of Parameters. From this perspective, size should represent the number of registered Parameters. So I think we need some other name...

Can you come up with any?
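
The ambiguity discussed here can be made concrete. This is a minimal sketch using a hypothetical mapping from parameter names to array shapes instead of real Chainer objects; it only illustrates the two numbers that "size" could plausibly refer to.

```python
import math

# Hypothetical stand-in for a link with one weight matrix and one bias vector.
shapes = {'W': (100, 10), 'b': (100,)}

# Reading 1: how many parameter objects are registered.
num_parameter_objects = len(shapes)

# Reading 2: how many scalar values those parameters hold in total.
num_scalars = sum(math.prod(shape) for shape in shapes.values())

print(num_parameter_objects, num_scalars)
```

A name for the new attribute has to make clear that it reports the second number, not the first.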

@mitmul
Member Author

mitmul commented Aug 7, 2017

@delta2323 OK, I'll add tests for Chain and ChainList.

@niboshi I understand what you pointed out. How about count or count_params ?

@niboshi
Member

niboshi commented Aug 8, 2017

@mitmul

> @niboshi I understand what you pointed out. How about count or count_params ?

Keras uses count_params, so it might be OK, but Keras does not have a Parameter class, so the name might still be confusing in Chainer. I would prefer a more descriptive name such as count_param_scalars.

By the way, I doubt the need for this feature.

You can write sum(param.size or 0 for param in link.params()) if we change the behavior of Parameter.size a bit (to return None if it's uninitialized, which is a reasonable change IMO). Such code is easy and straightforward for users to write.

I think it's better to hear other members' opinions as well.
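
The one-liner suggested in this comment depends on the proposed behavior that an uninitialized parameter reports a size of None. The following is a minimal sketch of that idea; FakeParam is a hypothetical stand-in and not chainer.Parameter.

```python
class FakeParam:
    """Hypothetical stand-in for a parameter; not chainer.Parameter."""

    def __init__(self, shape=None):
        # shape is None while the parameter is still uninitialized.
        self.shape = shape

    @property
    def size(self):
        # Assumed behavior from the proposal: None when uninitialized.
        if self.shape is None:
            return None
        count = 1
        for dim in self.shape:
            count *= dim
        return count


params = [FakeParam((100, 10)), FakeParam((100,)), FakeParam()]

# `size or 0` maps None to 0, so uninitialized parameters are silently skipped.
total = sum(p.size or 0 for p in params)
print(total)
```

Note that the expression skips uninitialized parameters without telling the caller which ones were left out.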

@mitmul
Member Author

mitmul commented Aug 8, 2017

Thank you for your comment.

Hmm, count_param_scalars is too long, so I think count_params() or count is enough. As for the necessity, this is not meant to fill a gap in Chainer's functionality, so it's certainly not essential; it just improves usability. I've heard that some users wonder how to get the number of trainable parameters in a Chainer model, because other frameworks provide an easy way to do so. Also, with the sum(...) approach it is difficult to know which parameters are not counted in the returned value, whereas Link.count raises warnings for them. For this reason, I think providing count or count_params() on Link is reasonable.
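
The difference argued for here, a count that also reports what it skipped, might look roughly like the sketch below. It assumes the parameters are available as (name, size) pairs with None marking an uninitialized entry; count_scalars is a hypothetical helper, not the method added by this PR.

```python
import warnings


def count_scalars(named_sizes):
    """Sum sizes, warning once for each entry whose size is None."""
    total = 0
    for name, size in named_sizes:
        if size is None:
            warnings.warn(
                "Parameter '{}' has not been initialized, so the resulting "
                "count will not include it.".format(name), UserWarning)
            continue
        total += size
    return total


named = [('/l1/W', None), ('/l1/b', 100)]

# Record the warnings so the skipped names can be inspected afterwards.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    total = count_scalars(named)

skipped = [str(w.message) for w in caught]
print(total, skipped)
```

Unlike a bare sum(...), the caller gets both the partial total and a message naming every parameter that was left out.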

@delta2323
Member

IMO, if counting the number of parameters in a link is routine work, it would not be strange for Chainer to provide it as a utility function.

@delta2323
Member

It seems to me that count and count_params could also remind some users of the number of Parameter instances attached to a link, in the same way size does. So far, I have not come up with a word that is simple enough yet points unambiguously to one of the two meanings.

@niboshi niboshi added the st:needs-discussion State indicating that discussions are needed before proceeding. label Aug 29, 2017
@stale

stale bot commented Nov 8, 2017

This issue has been automatically marked as stale because it has not had recent activity. It will be closed after 30 days if no further activity occurs. Thank you for your contributions.

@stale stale bot added the stale Not updated for a longer period of time. label Nov 8, 2017
@niboshi niboshi removed the stale Not updated for a longer period of time. label Nov 9, 2017
@delta2323
Member

@mitmul @niboshi Do you have any idea how to proceed with this issue?

@niboshi
Member

niboshi commented Nov 13, 2017

I prefer count_params among the names mentioned above. What do you think?

@delta2323
Member

delta2323 commented Nov 14, 2017

I think it is a good idea to align the semantics of property names with other frameworks if possible. In that sense, count_params is OK. (It seems PyTorch and Caffe do not have this functionality; there, we list the parameters inside a model and sum them up explicitly.)

@stale

stale bot commented Feb 12, 2018

This issue has been automatically marked as stale because it has not had recent activity. It will be closed after 30 days if no further activity occurs. Thank you for your contributions.

@stale stale bot added the stale Not updated for a longer period of time. label Feb 12, 2018
@delta2323
Member

@mitmul @niboshi What do you think of it?

@stale stale bot removed the stale Not updated for a longer period of time. label Feb 12, 2018
@mitmul
Member Author

mitmul commented Feb 26, 2018

I'll change the name to count_params.

chainer/link.py Outdated
    def count_params(self):
        """Count the total size of parameters.

        If the link contains uninitialized parameters, this raises warnings.
Member

I think we should emphasize that this method counts the number of scalars.
How about something like this?
(Just an example. Please fix as you like)

Counts the total number of parameters.

This method counts the total number of scalar values included in all the Parameters held by this link and its descendants.

Member Author

That's right. Good suggestion

chainer/link.py Outdated
    def count_params(self):
        """Count the total size of parameters.

        If the link contains uninitialized parameters, this raises warnings.
Member

this -> this method
warnings -> a warning

Member Author

Thanks. fixed.

chainer/link.py Outdated
        for name, param in self.namedparams():
            if param.array is None:
                warnings.warn(
                    '{} has not been initialized, so the resulting size will '
Member

{} -> Parameter '{}' ?
size -> count?

Member Author

Thanks. fixed.

@@ -789,6 +801,16 @@ def test_serialize(self):
        mocks['l1'].assert_called_with('x', self.l1.x.data)
        mocks['l2'].assert_called_with('x', self.l2.x.data)

    @unittest.skipUnless(
        six.PY3, 'Python2.7 has a bug in catch_warnings, so this test is '
        'skipped for Python2.7')
Member

What kind of bug? How about using testing.assert_warns?

Member Author

I forgot why I used this. I will try assert_warns

Member Author

Hmm, in Python 2.7, catch_warnings can't count the number of warnings correctly, so self.assertEqual(len(w), 2) fails.

Member

Ah, I see. Is this check needed? The warning is printed only once anyway.

Member Author

I think we can remove the warning count check. I removed it.

@mitmul
Member Author

mitmul commented Feb 27, 2018

I think I have finished addressing the review comments.

Member

@niboshi niboshi left a comment

LGTM except this comment.

chainer/link.py Outdated
        """Counts the total number of parameters.

        This method counts the total number of scalar values included in all
        the :class:`~chainer.Parameters` held by this link and its descendants.
Member

I think this should be:

:class:`~chainer.Parameter`\\ s

@niboshi
Member

niboshi commented Feb 27, 2018

@delta2323 LGTM except for the last comment.

Could you merge if LGTM to you?

@niboshi niboshi removed the st:needs-discussion State indicating that discussions are needed before proceeding. label Feb 27, 2018
@niboshi
Member

niboshi commented Feb 28, 2018

@mitmul Please use testing.assert_warns.

@delta2323
Member

Strictly speaking, this PR does not check that count_params issues no warning when all parameters are initialized. Could you add such test cases?

@@ -468,6 +470,12 @@ def test_update_enabled(self):
        self.link.enable_update()
        self.assertTrue(self.link.update_enabled)

    def test_count_params(self):
        self.assertEqual(self.link.count_params(), 8)
        with warnings.catch_warnings(record=True) as w:
Member

Use w to check that the number of warnings issued is correct (maybe 1) and that the type of the warning is UserWarning. Otherwise, remove as w, since it is not used.

@@ -789,6 +797,12 @@ def test_serialize(self):
        mocks['l1'].assert_called_with('x', self.l1.x.data)
        mocks['l2'].assert_called_with('x', self.l2.x.data)

    def test_count_params(self):
        self.assertEqual(self.c1.count_params(), 8)
        with warnings.catch_warnings(record=True) as w:
Member

ditto

@@ -1102,6 +1116,12 @@ def test_serialize(self):
        mocks['0'].assert_called_with('y', l1.y.data)
        mocks['1'].assert_called_with('x', l2.x.data)

    def test_count_params(self):
        self.assertEqual(self.c1.count_params(), 8)
        with warnings.catch_warnings(record=True) as w:
Member

ditto

@mitmul
Member Author

mitmul commented Mar 15, 2018

@delta2323 @niboshi assert_warns does not seem to return any value at its yield, so we can't get any information about the number of warnings raised. https://github.com/chainer/chainer/blob/master/chainer/testing/helper.py#L53
I think we need to either use catch_warnings directly or stop checking the number of warnings. What should I do?
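
The limitation described here follows from how generator-based context managers work: a bare yield hands None to the caller's `as` target. The sketch below illustrates this with a hypothetical helper, assert_warns_sketch, written only for this example; it is not the code in chainer/testing/helper.py.

```python
import contextlib
import warnings


@contextlib.contextmanager
def assert_warns_sketch(expected):
    """Hypothetical helper that fails unless an `expected` warning occurs."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        yield  # bare yield: the caller's `as` target receives None
    if not any(isinstance(w.message, expected) for w in caught):
        raise AssertionError(
            '{} was not raised'.format(expected.__name__))


with assert_warns_sketch(UserWarning) as target:
    warnings.warn('first', UserWarning)
    warnings.warn('second', UserWarning)

# The recorded list stays private to the helper, so the caller cannot
# tell whether one warning or two were issued.
print(target)
```

A helper of this shape can only answer whether a matching warning occurred at all. Counting warnings needs either direct use of warnings.catch_warnings(record=True) or a helper that yields the recorded list.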

@niboshi
Member

niboshi commented Mar 15, 2018

IMO, you can just use testing.assert_warns, and ignore the number of warnings.

@codecov-io

codecov-io commented Mar 15, 2018

Codecov Report

Merging #3101 into master will increase coverage by 1.48%.
The diff coverage is 100%.

@@            Coverage Diff            @@
##           master   #3101      +/-   ##
=========================================
+ Coverage   89.12%   90.6%   +1.48%     
=========================================
  Files         320     320              
  Lines       21211   21021     -190     
=========================================
+ Hits        18905   19047     +142     
+ Misses       2306    1974     -332
| Impacted Files | Coverage Δ | |
|---|---|---|
| chainer/link.py | 93.75% <100%> (+3.95%) | ⬆️ |
| chainer/datasets/svhn.py | 33.33% <0%> (-3.04%) | ⬇️ |
| chainer/function_hooks/debug_print.py | 97.36% <0%> (-2.64%) | ⬇️ |
| chainer/backends/cuda.py | 76.41% <0%> (-2.44%) | ⬇️ |
| chainer/optimizer.py | 81.81% <0%> (-0.7%) | ⬇️ |
| chainer/training/extensions/progress_bar.py | 15.71% <0%> (-0.47%) | ⬇️ |
| chainer/links/connection/n_step_rnn.py | 98.21% <0%> (ø) | ⬆️ |
| chainer/dataset/convert.py | 49.56% <0%> (ø) | ⬆️ |
| chainer/links/connection/n_step_lstm.py | 100% <0%> (ø) | ⬆️ |
| chainer/functions/connection/convolution_nd.py | 100% <0%> (ø) | ⬆️ |
| ... and 25 more | | |

Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 1759406...9860b9d. Read the comment docs.

@mitmul
Member Author

mitmul commented Mar 15, 2018

Hmm, but I did succeed in counting the number of warnings. When two warnings are raised, len(w) returns 2 if I use catch_warnings. But I guess this doesn't work in Python 2.

@mitmul
Member Author

mitmul commented Mar 15, 2018

@delta2323 Well, I added the test cases for this:

> Strictly speaking, this PR does not check count_params does not throw warning when all parameters are initialized. So could you add such test cases?

@niboshi
Member

niboshi commented Mar 15, 2018

I understand it's possible to count the warnings in Python 3, by using catch_warnings.
What I wonder is if it matters.
(No matter how many warnings are raised, the observable outcome for users is the same.)

Member

@delta2323 delta2323 left a comment

LGTM except comments

        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter('always')
            self.link.count_params()
        assert len(w) == 0
Member

You can check the condition simply with assert not w, because w is an (empty) list
(cf. https://docs.python.org/3.6/library/warnings.html#warnings.catch_warnings)

Member

@delta2323 delta2323 Mar 15, 2018

But the current code is also OK, as it may be more intuitive.
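
The two spellings of the check discussed in this thread are interchangeable because catch_warnings(record=True) hands back a plain list. The following is a minimal sketch with a hypothetical count_initialized function standing in for count_params; it does not use Chainer.

```python
import warnings


def count_initialized(sizes):
    """Hypothetical stand-in for count_params over a list of sizes."""
    total = 0
    for size in sizes:
        if size is None:
            warnings.warn('uninitialized parameter skipped', UserWarning)
        else:
            total += size
    return total


with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter('always')
    # Every entry is initialized here, so no warning should be recorded.
    total = count_initialized([6, 2])

assert not w         # an empty list is falsy
assert len(w) == 0   # the more explicit spelling of the same check
```

Which form to use is a readability choice; both fail as soon as a single warning is recorded.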

        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter('always')
            self.c2.count_params()
        assert len(w) == 0
Member

ditto

        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter('always')
            self.c2.count_params()
        assert len(w) == 0
Member

ditto

@mitmul
Member Author

mitmul commented Mar 15, 2018

@delta2323 Thanks for the comment. I fixed the points.

@delta2323
Member

Jenkins, test this please.

@delta2323
Member

@niboshi LGTM. Could you merge this PR if you approve it?

@niboshi
Member

niboshi commented Mar 23, 2018

LGTM!

@niboshi niboshi merged commit ecb6f2c into chainer:master Mar 23, 2018
@kmaehashi kmaehashi added this to the v5.0.0a1 milestone Apr 17, 2018
Labels
cat:feature Implementation that introduces new interfaces.
6 participants