Skip to content

Commit

Permalink
Merge pull request #7787 from shu65/bp-7600-opt
Browse files Browse the repository at this point in the history
[backport] Avoid unload module call in PureNcclCommunicator
  • Loading branch information
kuenishi committed Jul 23, 2019
2 parents 3356743 + 39e057a commit f6fbbfc
Showing 1 changed file with 3 additions and 18 deletions.
21 changes: 3 additions & 18 deletions chainermn/communicators/pure_nccl_communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def multi_node_mean_nccl(self, gpu_buffer_a, gpu_buffer_b,
self.nccl_comm.allReduce(gpu_buffer_a.ptr(),
gpu_buffer_b.ptr(), n_elems,
type_id, nccl.NCCL_SUM, stream.ptr)
- div_by_size = chainer.cuda.cupy.ElementwiseKernel(
+ div_by_size = chainer.cuda.elementwise(
'{} x'.format(dtype.name),
'{} y'.format(dtype.name),
'y = x*(1.0/{})'.format(self.size), 'div_by_size')
Expand All @@ -186,21 +186,6 @@ def multi_node_mean_nccl(self, gpu_buffer_a, gpu_buffer_b,
self.ensure_all_finite(gpu_buffer_a.array(n_elems, dtype=dtype))


- def _get_converting_kernel(src_dtype, dst_dtype, kernel_name):
- return chainer.cuda.cupy.ElementwiseKernel(
- '{} x'.format(src_dtype.name),
- '{} y'.format(dst_dtype.name),
- 'y = x', kernel_name)
-
-
- def _get_param_data_dtype(param):
- return param.data.dtype
-
-
- def _get_param_grad_dtype(param):
- return param.grad.dtype
-
-
class _ParamsData(object):
def __init__(self, params, attr_name, zero_fill):
n_params = len(params)
Expand Down Expand Up @@ -256,7 +241,7 @@ def _batched_unpack_params(params_data, buffer, dtype):


def _cupy_batched_pack_params():
- return chainer.cuda.cupy.RawKernel(r'''
+ return chainer.cuda.raw(r'''
#include <cupy/carray.cuh>
#define NCCL_FLOAT16 6
#define NCCL_FLOAT32 7
Expand Down Expand Up @@ -309,7 +294,7 @@ def _cupy_batched_pack_params():


def _cupy_batched_unpack_params():
- return chainer.cuda.cupy.RawKernel(r'''
+ return chainer.cuda.raw(r'''
#include <cupy/carray.cuh>
#define NCCL_FLOAT16 6
#define NCCL_FLOAT32 7
Expand Down

0 comments on commit f6fbbfc

Please sign in to comment.