Skip to content

Commit

Permalink
Merge pull request chainer#7594 from niboshi/optimizer-multidevice-lo…
Browse files Browse the repository at this point in the history
…ss-scale

Fix multi-device loss scaling
  • Loading branch information
hvy committed Aug 19, 2019
1 parent 0d1243c commit cc7fb31
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 8 deletions.
4 changes: 4 additions & 0 deletions chainer/optimizer.py
Expand Up @@ -244,6 +244,10 @@ def update(self, param):

self.t += 1

with chainer.using_device(param.device):
self.__update(param)

def __update(self, param):
try:
param_dtype = param.dtype
except RuntimeError:
Expand Down
24 changes: 16 additions & 8 deletions tests/chainer_tests/test_optimizer.py
Expand Up @@ -530,12 +530,8 @@ def setup_gpu(self, device=None):
self.target.to_gpu(device)
self.optimizer.setup(self.target)

def setup_chainerx(self, orig_xp):
if orig_xp is cuda.cupy:
self.target.to_device('cuda:0')
else:
assert orig_xp is np
self.target.to_device('native:0')
def setup_chainerx(self, device):
self.target.to_device(device)
self.optimizer.setup(self.target)

def check_update(self):
Expand All @@ -558,16 +554,28 @@ def test_update_gpu(self):
self.setup_gpu()
self.check_update()

@attr.multi_gpu(2)
def test_update_multi_gpu(self):
self.setup_gpu(1)
self.check_update()

@attr.chainerx
def test_update_chainerx_cpu(self):
self.setup_chainerx(np)
self.setup_chainerx('native:0')
self.check_update()

@attr.chainerx
@attr.gpu
def test_update_chainerx_gpu(self):
self.setup_gpu()
self.setup_chainerx(cuda.cupy)
self.setup_chainerx('cuda:0')
self.check_update()

@attr.chainerx
@attr.multi_gpu(2)
def test_update_chainerx_multi_gpu(self):
self.setup_gpu(1)
self.setup_chainerx('cuda:1')
self.check_update()


Expand Down

0 comments on commit cc7fb31

Please sign in to comment.