Keyword-only link argument in GradientMethod.__init__
niboshi committed Nov 30, 2017
1 parent 9f94cea commit 8336560
Showing 12 changed files with 155 additions and 38 deletions.
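For context (not part of the diff), a minimal sketch of how optimizer construction behaves after this change; `model` is an illustrative link name:

import chainer
from chainer import optimizers

model = chainer.Link()  # illustrative target link

# Before this commit the target could be passed positionally:
#     optimizer = optimizers.SGD(0.01, model)
# After it, the link must be given as a keyword argument (or via setup()):
optimizer = optimizers.SGD(lr=0.01, link=model)

# A link passed positionally now lands on a hyperparameter slot and is
# rejected with a TypeError:
#     optimizers.SGD(model)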
31 changes: 28 additions & 3 deletions chainer/optimizer.py
@@ -8,6 +8,7 @@
 from chainer import cuda
 from chainer import link as link_module
 from chainer import serializer as serializer_module
+from chainer.utils import argument
 from chainer import variable
 
 
@@ -53,6 +54,8 @@ class Hyperparameter(object):
     """
 
+    _parent = None
+
     def __init__(self, parent=None):
         self._parent = parent
 
@@ -61,6 +64,19 @@ def __getattr__(self, name):
             raise AttributeError('_parent is not set up yet')
         return getattr(self._parent, name)
 
+    def __setattr__(self, name, value):
+        # If the attribute is not defined as the class attribute of
+        # `Hyperparameter`, it's assumed to be a hyperparameter.
+        if not hasattr(Hyperparameter, name):
+            if not (isinstance(value, (numpy.ndarray, cuda.ndarray))
+                    or (not isinstance(value, (str, bytes))
+                        and numpy.isscalar(value))):
+                raise TypeError(
+                    'Hyperparameter must be a scalar or an array '
+                    '(name=\'{}\').\n'
+                    'Actual: {}'.format(name, type(value)))
+        super(Hyperparameter, self).__setattr__(name, value)
+
     def __repr__(self):
         d = self.get_dict()
         keys = sorted(d.keys())
@@ -523,12 +539,21 @@ class GradientMethod(Optimizer):
     """
 
-    def __init__(self, link=None):
+    _use_fp32_update = False
+
+    def __init__(self, **kwargs):
+        link, = argument.parse_kwargs(kwargs, ('link', None))
         super(GradientMethod, self).__init__()
         self.hyperparam = Hyperparameter()
-        if isinstance(link, link_module.Link):
+
+        if link is None:
+            pass
+        elif isinstance(link, link_module.Link):
             self.setup(link)
-        self._use_fp32_update = False
+        else:
+            raise TypeError(
+                'link argument must be an instance of chainer.Link.\n'
+                'Actual: {}'.format(type(link)))
 
     def setup(self, link):
         super(GradientMethod, self).setup(link)
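A minimal sketch (not part of the diff) of the validation added in Hyperparameter.__setattr__ above: scalars and NumPy/CuPy arrays are accepted as hyperparameter values, while strings, bytes, and other objects raise TypeError. The attribute names used here are made up for illustration.

import numpy as np
from chainer import optimizer

hp = optimizer.Hyperparameter()
hp.lr = 0.01                          # Python scalar: accepted
hp.decay = np.float32(1e-4)           # NumPy scalar: accepted
hp.table = np.ones((3,), np.float32)  # ndarray: accepted
try:
    hp.label = 'sgd'                  # str/bytes are explicitly rejected
except TypeError as err:
    print(err)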
6 changes: 4 additions & 2 deletions chainer/optimizers/ada_delta.py
@@ -2,6 +2,7 @@
 
 from chainer import cuda
 from chainer import optimizer
+from chainer.utils import argument
 
 
 _default_hyperparam = optimizer.Hyperparameter()
@@ -84,8 +85,9 @@ class AdaDelta(optimizer.GradientMethod):
     """
 
     def __init__(self, rho=_default_hyperparam.rho,
-                 eps=_default_hyperparam.eps, model=None):
-        super(AdaDelta, self).__init__(model)
+                 eps=_default_hyperparam.eps, **kwargs):
+        link, = argument.parse_kwargs(kwargs, ('link', None))
+        super(AdaDelta, self).__init__(link=link)
         self.hyperparam.rho = rho
         self.hyperparam.eps = eps
 
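Each optimizer constructor now follows the same pattern: pull `link` out of `**kwargs` with chainer.utils.argument.parse_kwargs and forward it to GradientMethod.__init__ as a keyword. The helper below is only an illustrative reimplementation of that pattern, not Chainer's actual parse_kwargs.

def parse_kwargs(kwargs, *name_and_values):
    # Pop each known keyword (falling back to its default) and reject
    # anything that remains, mirroring the keyword-only behaviour above.
    values = tuple(
        kwargs.pop(name, default) for name, default in name_and_values)
    if kwargs:
        raise TypeError('Unexpected keyword arguments: {}'.format(
            ', '.join(sorted(kwargs))))
    return values

link, = parse_kwargs({}, ('link', None))  # -> None when the caller omits link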
6 changes: 4 additions & 2 deletions chainer/optimizers/ada_grad.py
@@ -2,6 +2,7 @@
 
 from chainer import cuda
 from chainer import optimizer
+from chainer.utils import argument
 
 
 _default_hyperparam = optimizer.Hyperparameter()
@@ -75,8 +76,9 @@ class AdaGrad(optimizer.GradientMethod):
     """
 
     def __init__(self, lr=_default_hyperparam.lr,
-                 eps=_default_hyperparam.eps, model=None):
-        super(AdaGrad, self).__init__(model)
+                 eps=_default_hyperparam.eps, **kwargs):
+        link, = argument.parse_kwargs(kwargs, ('link', None))
+        super(AdaGrad, self).__init__(link=link)
         self.hyperparam.lr = lr
         self.hyperparam.eps = eps
 
6 changes: 4 additions & 2 deletions chainer/optimizers/adam.py
@@ -4,6 +4,7 @@
 
 from chainer import cuda
 from chainer import optimizer
+from chainer.utils import argument
 
 
 _default_hyperparam = optimizer.Hyperparameter()
@@ -111,8 +112,9 @@ def __init__(self,
                 beta1=_default_hyperparam.beta1,
                 beta2=_default_hyperparam.beta2,
                 eps=_default_hyperparam.eps,
-                 model=None):
-        super(Adam, self).__init__(model)
+                 **kwargs):
+        link, = argument.parse_kwargs(kwargs, ('link', None))
+        super(Adam, self).__init__(link=link)
         self.hyperparam.alpha = alpha
         self.hyperparam.beta1 = beta1
         self.hyperparam.beta2 = beta2
6 changes: 4 additions & 2 deletions chainer/optimizers/momentum_sgd.py
@@ -1,5 +1,6 @@
 from chainer import cuda
 from chainer import optimizer
+from chainer.utils import argument
 
 
 _default_hyperparam = optimizer.Hyperparameter()
@@ -69,8 +70,9 @@ class MomentumSGD(optimizer.GradientMethod):
     """
 
     def __init__(self, lr=_default_hyperparam.lr,
-                 momentum=_default_hyperparam.momentum, model=None):
-        super(MomentumSGD, self).__init__(model)
+                 momentum=_default_hyperparam.momentum, **kwargs):
+        link, = argument.parse_kwargs(kwargs, ('link', None))
+        super(MomentumSGD, self).__init__(link=link)
         self.hyperparam.lr = lr
         self.hyperparam.momentum = momentum
 
6 changes: 4 additions & 2 deletions chainer/optimizers/nesterov_ag.py
@@ -1,5 +1,6 @@
 from chainer import cuda
 from chainer import optimizer
+from chainer.utils import argument
 
 
 _default_hyperparam = optimizer.Hyperparameter()
@@ -76,8 +77,9 @@ class NesterovAG(optimizer.GradientMethod):
     """
 
    def __init__(self, lr=_default_hyperparam.lr,
-                 momentum=_default_hyperparam.momentum, model=None):
-        super(NesterovAG, self).__init__(model)
+                 momentum=_default_hyperparam.momentum, **kwargs):
+        link, = argument.parse_kwargs(kwargs, ('link', None))
+        super(NesterovAG, self).__init__(link=link)
         self.hyperparam.lr = lr
         self.hyperparam.momentum = momentum
 
6 changes: 4 additions & 2 deletions chainer/optimizers/rmsprop.py
@@ -2,6 +2,7 @@
 
 from chainer import cuda
 from chainer import optimizer
+from chainer.utils import argument
 
 
 _default_hyperparam = optimizer.Hyperparameter()
@@ -92,8 +93,9 @@ class RMSprop(optimizer.GradientMethod):
 
     def __init__(self, lr=_default_hyperparam.lr,
                  alpha=_default_hyperparam.alpha, eps=_default_hyperparam.eps,
-                 model=None):
-        super(RMSprop, self).__init__(model)
+                 **kwargs):
+        link, = argument.parse_kwargs(kwargs, ('link', None))
+        super(RMSprop, self).__init__(link=link)
         self.hyperparam.lr = lr
         self.hyperparam.alpha = alpha
         self.hyperparam.eps = eps
6 changes: 4 additions & 2 deletions chainer/optimizers/rmsprop_graves.py
@@ -2,6 +2,7 @@
 
 from chainer import cuda
 from chainer import optimizer
+from chainer.utils import argument
 
 
 _default_hyperparam = optimizer.Hyperparameter()
@@ -103,8 +104,9 @@ def __init__(self, lr=_default_hyperparam.lr,
                 alpha=_default_hyperparam.alpha,
                 momentum=_default_hyperparam.momentum,
                 eps=_default_hyperparam.eps,
-                 model=None):
-        super(RMSpropGraves, self).__init__(model)
+                 **kwargs):
+        link, = argument.parse_kwargs(kwargs, ('link', None))
+        super(RMSpropGraves, self).__init__(link=link)
         self.hyperparam.lr = lr
         self.hyperparam.alpha = alpha
         self.hyperparam.momentum = momentum
6 changes: 4 additions & 2 deletions chainer/optimizers/sgd.py
@@ -1,5 +1,6 @@
 from chainer import cuda
 from chainer import optimizer
+from chainer.utils import argument
 
 
 _default_hyperparam = optimizer.Hyperparameter()
@@ -50,8 +51,9 @@ class SGD(optimizer.GradientMethod):
     """
 
-    def __init__(self, lr=_default_hyperparam.lr, model=None):
-        super(SGD, self).__init__(model)
+    def __init__(self, lr=_default_hyperparam.lr, **kwargs):
+        link, = argument.parse_kwargs(kwargs, ('link', None))
+        super(SGD, self).__init__(link=link)
         self.hyperparam.lr = lr
 
     lr = optimizer.HyperparameterProxy('lr')
6 changes: 4 additions & 2 deletions chainer/optimizers/smorms3.py
@@ -2,6 +2,7 @@
 
 from chainer import cuda
 from chainer import optimizer
+from chainer.utils import argument
 
 
 _default_hyperparam = optimizer.Hyperparameter()
@@ -88,8 +89,9 @@ class SMORMS3(optimizer.GradientMethod):
     """
 
     def __init__(self, lr=_default_hyperparam.lr,
-                 eps=_default_hyperparam.eps, model=None):
-        super(SMORMS3, self).__init__(model)
+                 eps=_default_hyperparam.eps, **kwargs):
+        link, = argument.parse_kwargs(kwargs, ('link', None))
+        super(SMORMS3, self).__init__(link=link)
         self.hyperparam.lr = lr
         self.hyperparam.eps = eps
 
66 changes: 49 additions & 17 deletions tests/chainer_tests/optimizers_tests/test_optimizers.py
@@ -20,29 +20,61 @@
         optimizers.SMORMS3,
     ]
 }))
-class TestOptimizerHyperparameter(unittest.TestCase):
+class TestGradientMethodHyperparameter(unittest.TestCase):
 
-    def setUp(self):
-        self.target = chainer.Link()
-        with self.target.init_scope():
-            self.target.w = chainer.Parameter()
+    def create_target(self):
+        target = chainer.Link()
+        with target.init_scope():
+            target.w = chainer.Parameter()
+        return target
 
-    def create(self, *args, **kwargs):
-        self.optimizer = self.impl(*args, **kwargs)
-        self.optimizer.setup(self.target)
+    def get_hyperparam(self, target, name):
+        return getattr(target.w.update_rule.hyperparam, name)
 
-    def get_hyperparam(self, name):
-        return getattr(self.target.w.update_rule.hyperparam, name)
-
-    def test_hyperparams(self):
-        self.create()
-        default = self.optimizer.hyperparam.get_dict()
+    def check_hyperparams(self, create):
+        # Retrieve the default hyperparameters of the optimizer.
+        target = self.create_target()
+        optimizer = create(target)
+        default = optimizer.hyperparam.get_dict()
+
         for name, default_value in six.iteritems(default):
-            self.create()
-            self.assertEqual(self.get_hyperparam(name), default_value)
+            # Without explicit values, the hyperparam of the target link
+            # must be initialized with the default value.
+            target = self.create_target()
+            optimizer = create(target)
+            assert self.get_hyperparam(target, name) == default_value
 
+            # With explicit values, the hyperparam of the target link
+            # must be initialized with that value.
+            target = self.create_target()
             new_value = default_value + 0.1
-            self.create(**{name: new_value})
-            self.assertEqual(self.get_hyperparam(name), new_value)
+            optimizer = create(target, **{name: new_value})
+            assert self.get_hyperparam(target, name) == new_value
+
+    def test_hyperparams_setup_with_init(self):
+        # Test hyperparameters, using an optimizer whose model is set up via
+        # the __init__ argument.
+        def create(target, *args, **kwargs):
+            optimizer = self.impl(*args, link=target, **kwargs)
+            return optimizer
+        self.check_hyperparams(create)
+
+    def test_hyperparams_separate_setup(self):
+        # Test hyperparameters, using an optimizer whose model is set up via
+        # the setup() method.
+        def create(target, *args, **kwargs):
+            optimizer = self.impl(*args, **kwargs)
+            optimizer.setup(target)
+            return optimizer
+        self.check_hyperparams(create)
+
+    def test_link_keyword_only_argument(self):
+        # The link argument must be given as a keyword (link=).
+        # This test assumes every optimizer takes a hyperparameter as its
+        # first positional argument, so a positional link is rejected.
+        target = self.create_target()
+        with self.assertRaises(TypeError):
+            self.impl(target)
+
 
 testing.run_module(__name__, __file__)
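The check_hyperparams helper above verifies that hyperparameters set on the optimizer propagate to the per-parameter update rules. A hedged sketch of that behaviour (not part of the diff), assuming a one-parameter link like the one built by create_target():

import numpy as np
import chainer
from chainer import optimizers

target = chainer.Link()
with target.init_scope():
    target.w = chainer.Parameter(np.zeros((2,), dtype=np.float32))

opt = optimizers.SGD(lr=0.05)
opt.setup(target)
# The update rule's hyperparam falls back to the optimizer's value.
print(target.w.update_rule.hyperparam.lr)  # 0.05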
42 changes: 42 additions & 0 deletions tests/chainer_tests/test_optimizer.py
@@ -53,6 +53,48 @@ def test_deep_copy(self):
         self.assertIs(child_copy.parent, parent_copy)
 
 
+@testing.parameterize(*testing.product(
+    {'value': [
+        1,
+        True,
+        2.1,
+        2.1j,
+        np.inf,
+        np.nan,
+        np.ones((1, 2), dtype=np.float32),
+        np.ones((1,), dtype=np.float32),
+        np.ones((0,), dtype=np.float32),
+        np.ones((), dtype=np.float32),
+    ]},
+))
+class TestHyperparameterValidType(unittest.TestCase):
+
+    def test_valid_value_type(self):
+        # Must not raise error
+        hp = optimizer.Hyperparameter()
+        hp.newparam = self.value
+
+
+@testing.parameterize(*testing.product(
+    {'value': [
+        None,
+        'str',
+        '1',
+        b'1',
+        object(),
+        chainer.Link(),
+        chainer.Variable(np.ones((), dtype=np.float32)),
+    ]},
+))
+class TestHyperparameterInvalidType(unittest.TestCase):
+
+    def test_invalid_value_type(self):
+        # Must raise error
+        hp = optimizer.Hyperparameter()
+        with self.assertRaises(TypeError):
+            hp.newparam = self.value
+
+
 class TestUpdateRule(unittest.TestCase):
 
     def setUp(self):
