Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add count_params method to Link #3101

Merged
merged 16 commits into from Mar 23, 2018
Merged
26 changes: 26 additions & 0 deletions chainer/link.py
Expand Up @@ -578,6 +578,32 @@ def serialize(self, serializer):
for name in self._persistent:
d[name] = serializer(name, d[name])

def count_params(self):
"""Counts the total number of parameters.

This method counts the total number of scalar values included in all
the :class:`~chainer.Parameter`\\ s held by this link and its
descendants.

If the link containts uninitialized parameters, this method raises a
warning.

Returns:
The total size of parameters (int)

"""

size = 0
for name, param in self.namedparams():
if param.array is None:
warnings.warn(
'Parameter \'{}\' has not been initialized, so the '
'resulting count will not include the number of parameters'
' in it.'.format(name))
continue
size += param.size
return size


class Chain(Link):

Expand Down
44 changes: 44 additions & 0 deletions tests/chainer_tests/test_link.py
@@ -1,5 +1,6 @@
import copy
import unittest
import warnings

import mock
import numpy
Expand Down Expand Up @@ -468,6 +469,21 @@ def test_update_enabled(self):
self.link.enable_update()
self.assertTrue(self.link.update_enabled)

def test_count_params(self):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
self.link.count_params()
assert len(w) == 2
assert w[0].category is UserWarning
assert self.link.count_params() == 8

self.link.u.initialize((2, 3))
self.link.v.initialize((2, 3))
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
self.link.count_params()
assert not w


class CountParameter(chainer.Parameter):

Expand Down Expand Up @@ -789,6 +805,20 @@ def test_serialize(self):
mocks['l1'].assert_called_with('x', self.l1.x.data)
mocks['l2'].assert_called_with('x', self.l2.x.data)

def test_count_params(self):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
self.c2.count_params()
assert len(w) == 1
assert w[0].category is UserWarning
assert self.c1.count_params() == 8

self.c2.l3.x.initialize((3,))
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
self.c2.count_params()
assert not w


class TestChainList(unittest.TestCase):

Expand Down Expand Up @@ -1102,6 +1132,20 @@ def test_serialize(self):
mocks['0'].assert_called_with('y', l1.y.data)
mocks['1'].assert_called_with('x', l2.x.data)

def test_count_params(self):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
self.c2.count_params()
assert len(w) == 1
assert w[0].category is UserWarning
assert self.c1.count_params() == 8

self.c2[0][0].y.initialize((2, 3))
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
self.c2.count_params()
assert not w


@attr.ideep
class TestIntel64(unittest.TestCase):
Expand Down