This example shows how to parallelize a CNN in a channel-wise manner. This kind of parallelization is useful with large batch sizes or with high-resolution images.
On each process, the basic strategy is
- to pick the channels that the process is responsible for,
- to apply convolution, and
- to use allgather to combine the outputs of all channels into a single tensor.
A parallel convolution model could be implemented like this:

    import chainer
    import chainer.functions as F
    import chainermn
    import numpy as np


    class ParallelConvolution2D(chainer.links.Convolution2D):

        def __init__(self, comm, in_channels, out_channels, *args, **kwargs):
            self.comm = comm
            self.in_channels = in_channels
            self.out_channels = out_channels
            # Build the local convolution with only this process's share
            # of the input and output channels.
            super(ParallelConvolution2D, self).__init__(
                self._in_channel_size, self._out_channel_size, *args, **kwargs)

        def __call__(self, x):
            # Keep only the input channels assigned to this process.
            x = x[:, self._channel_indices, :, :]
            y = super(ParallelConvolution2D, self).__call__(x)
            # Gather the partial outputs of all processes and concatenate
            # them along the channel axis.
            ys = chainermn.functions.allgather(self.comm, y)
            return F.concat(ys, axis=1)

        def _channel_size(self, n_channel):
            # Return the size of the corresponding channels.
            n_proc = self.comm.size
            i_proc = self.comm.rank
            return n_channel // n_proc + (1 if i_proc < n_channel % n_proc else 0)

        @property
        def _in_channel_size(self):
            return self._channel_size(self.in_channels)

        @property
        def _out_channel_size(self):
            return self._channel_size(self.out_channels)

        @property
        def _channel_indices(self):
            # Return the indices of the corresponding channel.
            indices = np.arange(self.in_channels)
            indices = indices[indices % self.comm.size == 0] + self.comm.rank
            return [i for i in indices if i < self.in_channels]

where comm is a ChainerMN communicator (see chainermn-communicator).
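
The channel assignment used by the class above can be checked without MPI. The following is a minimal sketch in plain Python, assuming the same round-robin rule as `_channel_size` and `_channel_indices`; the free functions `channel_size` and `channel_indices` and the sizes 10 and 4 are hypothetical names and values introduced only for this illustration, not part of ChainerMN.

```python
def channel_size(n_channel, n_proc, i_proc):
    """Number of channels handled by process i_proc (mirrors _channel_size)."""
    return n_channel // n_proc + (1 if i_proc < n_channel % n_proc else 0)


def channel_indices(n_channel, n_proc, i_proc):
    """Round-robin channel indices for process i_proc.

    Equivalent to _channel_indices above: rank r takes channels
    r, r + n_proc, r + 2 * n_proc, ... below n_channel.
    """
    return list(range(i_proc, n_channel, n_proc))


if __name__ == "__main__":
    n_channel, n_proc = 10, 4  # hypothetical sizes chosen for illustration
    for rank in range(n_proc):
        indices = channel_indices(n_channel, n_proc, rank)
        # The number of indices always matches the computed channel size.
        assert len(indices) == channel_size(n_channel, n_proc, rank)
        print(rank, indices)
```

With 10 channels and 4 processes, ranks 0 and 1 each take three channels and ranks 2 and 3 each take two, so every channel is covered exactly once even when the channel count is not divisible by the number of processes.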
ParallelConvolution2D can simply replace the original Convolution2D. For the first convolution layer, all processes must input the same images to the model. MultiNodeIterator distributes the same batches to all processes every iteration:

    # Only rank 0 keeps the real datasets; the other processes hold
    # empty placeholders and receive their batches through the iterator.
    if comm.rank != 0:
        train = chainermn.datasets.create_empty_dataset(train)
        test = chainermn.datasets.create_empty_dataset(test)

    train_iter = chainermn.iterators.create_multi_node_iterator(
        chainer.iterators.SerialIterator(train, args.batchsize), comm)
    test_iter = chainermn.iterators.create_multi_node_iterator(
        chainer.iterators.SerialIterator(test, args.batchsize,
                                         repeat=False, shuffle=False),
        comm)

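
The effect of sharing batches can be illustrated with a single-process simulation. This is only a conceptual sketch: it does not use MPI or ChainerMN and makes no claim about how the multi-node iterator is implemented internally. The function `simulate_multi_node_batches` and the toy dataset are hypothetical and exist only for this illustration.

```python
def simulate_multi_node_batches(dataset, batchsize, n_proc):
    """Yield, per iteration, the batch seen by each simulated rank.

    Only the simulated rank 0 reads `dataset`; every other rank receives
    a copy of that batch, standing in for the distribution step.
    """
    for start in range(0, len(dataset), batchsize):
        master_batch = dataset[start:start + batchsize]  # read on rank 0 only
        # Stand-in for the distribution step: each rank gets an identical copy.
        yield [list(master_batch) for _ in range(n_proc)]


if __name__ == "__main__":
    data = list(range(7))  # hypothetical toy dataset
    for per_rank in simulate_multi_node_batches(data, batchsize=3, n_proc=2):
        # Every rank sees the same samples in the same iteration.
        assert all(batch == per_rank[0] for batch in per_rank)
        print(per_rank)
```

Because every process slices different channels out of the same images, the outputs gathered by allgather all belong to the same samples, which is why identical batches are required.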
Example code with a training script for VGG16 parallelization is available here.