Skip to content

Commit

Permalink
Merge pull request #4483 from okuta/reduce-user-warnings
Browse files Browse the repository at this point in the history
Reduce UserWarning
  • Loading branch information
bkvogel committed Mar 19, 2018
2 parents 3a44f37 + 23b9264 commit 6b4f388
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 9 deletions.
12 changes: 7 additions & 5 deletions tests/chainer_tests/test_optimizer.py
Expand Up @@ -200,23 +200,23 @@ def setUp(self):
np.arange(3, -3, -1, dtype=np.float32).reshape(2, 3))

def test_add_hook(self):
h1 = mock.MagicMock()
h1 = mock.MagicMock(timing='pre')
h1.call_for_each_param = False
self.optimizer.setup(self.target)
self.optimizer.add_hook(h1, 'h1')
self.optimizer.call_hooks()
h1.assert_called_with(self.optimizer)

def test_add_hook_call_for_each_param(self):
h1 = mock.MagicMock()
h1 = mock.MagicMock(timing='pre')
h1.call_for_each_param = True
self.optimizer.setup(self.target)
self.optimizer.add_hook(h1, 'h1')
self.optimizer.call_hooks()
h1.assert_called_with(self.target.param.update_rule, self.target.param)

def test_remove_hook(self):
h1 = mock.MagicMock()
h1 = mock.MagicMock(timing='pre')
self.optimizer.setup(self.target)
self.optimizer.add_hook(h1, 'h1')
self.optimizer.remove_hook('h1')
Expand All @@ -225,9 +225,9 @@ def test_remove_hook(self):

def test_duplicated_hook(self):
self.optimizer.setup(self.target)
self.optimizer.add_hook(lambda s: None, 'h1')
self.optimizer.add_hook(lambda s: None, 'h1', timing='pre')
with self.assertRaises(KeyError):
self.optimizer.add_hook(lambda s: None, 'h1')
self.optimizer.add_hook(lambda s: None, 'h1', timing='pre')

def test_invalid_hook(self):
with self.assertRaises(TypeError):
Expand Down Expand Up @@ -382,6 +382,7 @@ def create_update_rule(self):
class DummyHook(object):

name = 'Dummy'
timing = 'pre'

def __init__(self, test):
self.test = test
Expand All @@ -395,6 +396,7 @@ def __call__(self, opt):
class CleargradHook(object):

name = 'Cleargrad'
timing = 'pre'

def __init__(self, _):
pass
Expand Down
Expand Up @@ -110,7 +110,8 @@ def test_resumed_trigger_backward_compat(self):
np.savez(f, dummy=0)

trigger = training.triggers.IntervalTrigger(*self.interval)
serializers.load_npz(f.name, trigger)
with testing.assert_warns(UserWarning):
serializers.load_npz(f.name, trigger)
for expected in self.expected[self.resume:]:
trainer.updater.update()
self.assertEqual(trigger(trainer), expected)
Expand Down
Expand Up @@ -128,7 +128,8 @@ def test_resumed_trigger_backward_compat(self):
np.savez(f, dummy=0)

trigger = training.triggers.ManualScheduleTrigger(*self.schedule)
serializers.load_npz(f.name, trigger)
with testing.assert_warns(UserWarning):
serializers.load_npz(f.name, trigger)
for expected in self.expected[self.resume:]:
trainer.updater.update()
self.assertEqual(trigger(trainer), expected)
Expand Down
Expand Up @@ -181,8 +181,10 @@ def test_update_uses_raw_array(self):
dataset, len(devices))]
optimizer = chainer.optimizers.SGD(lr=1.0)
optimizer.setup(model)
updater = mpu.MultiprocessParallelUpdater(
iters, optimizer, devices=devices)

with testing.assert_warns(UserWarning):
updater = mpu.MultiprocessParallelUpdater(
iters, optimizer, devices=devices)
updater.update()

self.assertEqual(model.call_called, 1)
Expand Down

0 comments on commit 6b4f388

Please sign in to comment.