add example 3 & 4 for model parallel tutorial
levelfour committed Dec 6, 2018
1 parent d7f4b55 commit 3f6a535
Showing 4 changed files with 107 additions and 0 deletions.
Binary file added docs/image/model_parallel/averaging.png
Binary file added docs/image/model_parallel/parallel_conv.png
66 changes: 66 additions & 0 deletions docs/source/chainermn/model_parallel/example3_parallel_conv.rst
@@ -1,2 +1,68 @@
Example 3: Channel-wise Parallel Convolution
============================================

This example parallelizes a CNN in a channel-wise manner.
This parallelization is useful with a large batch size or with high-resolution images.

.. figure:: ../../../image/model_parallel/parallel_conv.png
:align: center

The basic strategy is, on each process, to:

1. pick the channels that the process is responsible for
2. apply the convolution to those channels
3. use ``allgather`` to combine the partial outputs into one tensor

The following implementation applies this strategy::

    import numpy as np

    import chainer
    import chainer.functions as F
    import chainermn


    class ParallelConvolution2D(chainer.links.Convolution2D):
        def __init__(self, comm, in_channels, out_channels, *args, **kwargs):
            self.comm = comm
            self.in_channels = in_channels
            self.out_channels = out_channels
            super(ParallelConvolution2D, self).__init__(
                self._in_channel_size, self._out_channel_size, *args, **kwargs)

        def __call__(self, x):
            # Restrict the input to the channels this process handles.
            x = x[:, self._channel_indices, :, :]
            y = super(ParallelConvolution2D, self).__call__(x)
            # Collect the partial outputs from all processes and
            # concatenate them along the channel axis.
            ys = chainermn.functions.allgather(self.comm, y)
            return F.concat(ys, axis=1)

        def _channel_size(self, n_channel):
            # Return the number of channels assigned to this process.
            n_proc = self.comm.size
            i_proc = self.comm.rank
            return n_channel // n_proc + (1 if i_proc < n_channel % n_proc else 0)

        @property
        def _in_channel_size(self):
            return self._channel_size(self.in_channels)

        @property
        def _out_channel_size(self):
            return self._channel_size(self.out_channels)

        @property
        def _channel_indices(self):
            # Return the indices of the input channels assigned to this process.
            indices = np.arange(self.in_channels)
            indices = indices[indices % self.comm.size == 0] + self.comm.rank
            return [i for i in indices if i < self.in_channels]
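
To see how channels are divided among processes, here is a standalone sketch
of the logic in ``_channel_indices`` and ``_channel_size``, assuming a
hypothetical setting of 4 processes and 6 input channels::

    import numpy as np

    size, n_channel = 4, 6  # hypothetical communicator size and channel count
    for rank in range(size):
        indices = np.arange(n_channel)
        indices = indices[indices % size == 0] + rank
        print(rank, [i for i in indices if i < n_channel])
    # 0 [0, 4]
    # 1 [1, 5]
    # 2 [2]
    # 3 [3]

Each process takes every ``comm.size``-th channel starting from its own rank,
and ``_channel_size`` returns the matching counts (2, 2, 1 and 1 here).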

``ParallelConvolution2D`` can simply replace the original ``Convolution2D``
(see the sketch after the iterator example below).
For the first convolution layer, the input images must be shared among all processes.
``MultiNodeIterator`` distributes the same batches to all processes on every iteration::

    if comm.rank != 0:
        train = chainermn.datasets.create_empty_dataset(train)
        test = chainermn.datasets.create_empty_dataset(test)

    train_iter = chainermn.iterators.create_multi_node_iterator(
        chainer.iterators.SerialIterator(train, args.batchsize), comm)
    test_iter = chainermn.iterators.create_multi_node_iterator(
        chainer.iterators.SerialIterator(test, args.batchsize,
                                         repeat=False, shuffle=False),
        comm)
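
As a minimal sketch of such a replacement (the network structure and layer
sizes here are hypothetical, not taken from the VGG16 example below),
``ParallelConvolution2D`` is used exactly where ``Convolution2D`` would be,
with ``L`` standing for ``chainer.links``::

    class ParallelConvNet(chainer.Chain):
        def __init__(self, comm, n_out=10):
            super(ParallelConvNet, self).__init__()
            with self.init_scope():
                # Each process convolves only its assigned channels;
                # the allgather inside ParallelConvolution2D restores
                # the full channel dimension after every layer.
                self.conv1 = ParallelConvolution2D(comm, 3, 64, ksize=3, pad=1)
                self.conv2 = ParallelConvolution2D(comm, 64, 128, ksize=3, pad=1)
                self.fc = L.Linear(None, n_out)

        def __call__(self, x):
            h = F.relu(self.conv1(x))
            h = F.max_pooling_2d(h, 2)
            h = F.relu(self.conv2(h))
            return self.fc(h)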

An example code for VGG16 parallelization is available `here <https://github.com/chainer/chainer/blob/master/examples/chainermn/parallel_convolution/>`__.
41 changes: 41 additions & 0 deletions docs/source/chainermn/model_parallel/example4_ensemble.rst
@@ -1,2 +1,43 @@
Example 4: Ensemble
===================

An averaging ensemble is one application to which collective communications can be effectively applied.

.. figure:: ../../../image/model_parallel/averaging.png
:align: center

The following wrapper makes a model-parallel averaging ensemble easier to implement::

    class Averaging(chainer.Chain):
        def __init__(self, comm, block):
            super(Averaging, self).__init__()
            self.comm = comm
            with self.init_scope():
                self.block = block

        def __call__(self, x):
            y = self.block(x)
            if not chainer.config.train:
                # At test time, collect the outputs of all processes
                # and average them; at training time, each process
                # simply returns the output of its own branch.
                y = chainermn.functions.allgather(self.comm, y)
                y = F.stack(y, axis=0)
                y = F.average(y, axis=0)

            return y

Then, any links wrapped by ``Averaging`` are ready to be parallelized and averaged::

    class Model(chainer.Chain):
        def __init__(self, comm):
            super(Model, self).__init__()
            self.comm = comm
            with self.init_scope():
                # d0, ..., d3 are layer sizes defined elsewhere.
                self.l1 = L.Linear(d0, d1)
                self.l2 = L.Linear(d1, d2)
                self.l3 = Averaging(self.comm, L.Linear(d2, d3))

        def __call__(self, x):
            h = F.relu(self.l1(x))
            h = F.relu(self.l2(h))
            y = F.relu(self.l3(h))
            return y
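
At training time, ``Averaging`` simply passes through the output of the wrapped
link, so each process trains its own branch of ``l3`` independently; the
``allgather`` and average only run at test time. A minimal evaluation sketch
(the communicator setup and the input batch ``x`` are assumptions)::

    comm = chainermn.create_communicator()
    model = Model(comm)

    with chainer.using_config('train', False):
        # All processes must feed the same batch here, so that the
        # averaged predictions refer to the same inputs.
        y = model(x)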
