Skip to content

Commit

Permalink
Merge pull request chainer#7718 from shu65/fix_mnbn_bug
Browse files Browse the repository at this point in the history
Fix create_mnbn_model() bug
  • Loading branch information
kuenishi committed Aug 1, 2019
1 parent 3810582 commit ffc84b4
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 13 deletions.
7 changes: 3 additions & 4 deletions chainermn/links/create_mnbn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,10 @@ def create_mnbn_model(link, comm, communication_backend='auto'):
new_link.__dict__[name] = new_child
return new_link
elif isinstance(link, chainer.Sequential):
new_children = [
create_mnbn_model(l, comm, communication_backend) for l in link]
new_link = copy.deepcopy(link)
for i, new_child in enumerate(new_children):
new_link._layers[i] = new_child
for i, l in enumerate(link):
new_l = create_mnbn_model(l, comm, communication_backend)
new_link[i] = new_l
return new_link
elif isinstance(link, chainer.ChainList):
new_children = [
Expand Down
60 changes: 51 additions & 9 deletions tests/chainermn_tests/links_tests/test_create_mnbn_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import unittest

import chainer
import chainer.testing
import chainer.testing.attr
import chainermn


Expand Down Expand Up @@ -35,7 +37,7 @@ class TestCreateMnBnModel(unittest.TestCase):
def setUp(self):
self.communicator = chainermn.create_communicator('naive')

def test_create_mnbn_model_chain(self):
def check_create_mnbn_model_chain(self, gpu):
model = BnChain(3)
mnbn_model = chainermn.links.create_mnbn_model(model,
self.communicator)
Expand All @@ -44,8 +46,16 @@ def test_create_mnbn_model_chain(self):
self.assertTrue(
isinstance(mnbn_model.bn,
chainermn.links.MultiNodeBatchNormalization))
if gpu:
device_id = self.communicator.intra_rank
mnbn_model.to_gpu(device=device_id)
else:
device_id = -1
with chainer.using_device(mnbn_model.device):
x = mnbn_model.xp.zeros((1, 1, 1, 1))
mnbn_model(x)

def test_create_mnbn_model_chain_list(self):
def check_create_mnbn_model_chain_list(self, gpu):
model = BnChainList(3)
mnbn_model = chainermn.links.create_mnbn_model(model,
self.communicator)
Expand All @@ -54,8 +64,16 @@ def test_create_mnbn_model_chain_list(self):
self.assertTrue(
isinstance(mnbn_model[1],
chainermn.links.MultiNodeBatchNormalization))
if gpu:
device_id = self.communicator.intra_rank
mnbn_model.to_gpu(device=device_id)
else:
device_id = -1
with chainer.using_device(mnbn_model.device):
x = mnbn_model.xp.zeros((1, 1, 1, 1))
mnbn_model(x)

def test_create_mnbn_model_sequential(self):
def check_create_mnbn_model_sequential(self, gpu):
size = 3
model = chainer.Sequential(
chainer.links.Convolution2D(
Expand All @@ -65,9 +83,33 @@ def test_create_mnbn_model_sequential(self):
)
mnbn_model = chainermn.links.create_mnbn_model(model,
self.communicator)
self.assertTrue(isinstance(mnbn_model[0],
chainer.links.Convolution2D))
self.assertTrue(
isinstance(mnbn_model[1],
chainermn.links.MultiNodeBatchNormalization))
self.assertTrue(mnbn_model[2] == chainer.functions.relu)

if gpu:
device_id = self.communicator.intra_rank
mnbn_model.to_gpu(device=device_id)
else:
device_id = -1
with chainer.using_device(mnbn_model.device):
x = mnbn_model.xp.zeros((1, 1, 1, 1))
mnbn_model(x)

def test_create_mnbn_model_chain_cpu(self):
self.check_create_mnbn_model_chain(gpu=False)

def test_create_mnbn_model_chain_list_cpu(self):
self.check_create_mnbn_model_chain_list(gpu=False)

def test_create_mnbn_model_sequential_cpu(self):
self.check_create_mnbn_model_sequential(gpu=False)

@chainer.testing.attr.gpu
def test_create_mnbn_model_chain_gpu(self):
self.check_create_mnbn_model_chain(gpu=True)

@chainer.testing.attr.gpu
def test_create_mnbn_model_chain_list_gpu(self):
self.check_create_mnbn_model_chain_list(gpu=True)

@chainer.testing.attr.gpu
def test_create_mnbn_model_sequential_gpu(self):
self.check_create_mnbn_model_sequential(gpu=True)

0 comments on commit ffc84b4

Please sign in to comment.