Skip to content

Commit

Permalink
Merge pull request #2845 from Hakuyume/fix-shift-extension
Browse files Browse the repository at this point in the history
Fix resuming issue of *Shift extensions
  • Loading branch information
niboshi committed Jun 17, 2017
2 parents 8f8c0d4 + fb00320 commit a44f4ec
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 0 deletions.
4 changes: 4 additions & 0 deletions chainer/training/extensions/exponential_shift.py
@@ -1,5 +1,7 @@
from __future__ import division

import numpy as np

from chainer.training import extension


Expand Down Expand Up @@ -70,6 +72,8 @@ def __call__(self, trainer):
def serialize(self, serializer):
self._t = serializer('_t', self._t)
self._last_value = serializer('_last_value', self._last_value)
if isinstance(self._last_value, np.ndarray):
self._last_value = np.asscalar(self._last_value)

def _get_optimizer(self, trainer):
return self._optimizer or trainer.updater.get_optimizer('main')
Expand Down
4 changes: 4 additions & 0 deletions chainer/training/extensions/linear_shift.py
@@ -1,5 +1,7 @@
from __future__ import division

import numpy as np

from chainer.training import extension


Expand Down Expand Up @@ -55,6 +57,8 @@ def __call__(self, trainer):
def serialize(self, serializer):
self._t = serializer('_t', self._t)
self._last_value = serializer('_last_value', self._last_value)
if isinstance(self._last_value, np.ndarray):
self._last_value = np.asscalar(self._last_value)

def _get_optimizer(self, trainer):
return self._optimizer or trainer.updater.get_optimizer('main')
Expand Down
Expand Up @@ -77,6 +77,7 @@ def test_resume(self):

new_extension.initialize(new_trainer)
self.assertEqual(new_optimizer.x, self.optimizer.x)
self.assertIsInstance(new_optimizer.x, float)


class TestExponentialShiftInvalidArgument(unittest.TestCase):
Expand Down
Expand Up @@ -65,6 +65,7 @@ def test_resume(self):

new_extension.initialize(new_trainer)
self.assertEqual(new_optimizer.x, self.optimizer.x)
self.assertIsInstance(new_optimizer.x, float)


testing.run_module(__name__, __file__)

0 comments on commit a44f4ec

Please sign in to comment.