Skip to content

Commit

Permalink
Merge pull request #7942 from emcastillo/dist-array-shino
Browse files Browse the repository at this point in the history
[takeover] Add distributed ndarray
  • Loading branch information
asi1024 committed Nov 7, 2023
2 parents a29335a + 58fde4f commit f3536c6
Show file tree
Hide file tree
Showing 20 changed files with 3,305 additions and 320 deletions.
2 changes: 1 addition & 1 deletion cupy/_core/_kernel.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ cdef list _get_out_args_with_params(
list out_args, tuple out_types,
const shape_t& out_shape, tuple out_params, bint is_size_specified)

cdef _check_peer_access(_ndarray_base arr, int device_id)
cpdef _check_peer_access(_ndarray_base arr, int device_id)

cdef list _preprocess_args(int dev_id, args, bint use_c_scalar)

Expand Down
2 changes: 1 addition & 1 deletion cupy/_core/_kernel.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ cdef inline int _get_kind_score(int kind):


@cython.profile(False)
cdef inline _check_peer_access(_ndarray_base arr, int device_id):
cpdef inline _check_peer_access(_ndarray_base arr, int device_id):
if arr.data.device_id == device_id:
return

Expand Down
4 changes: 4 additions & 0 deletions cupy/_core/_reduction.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,10 @@ cdef class _SimpleReductionKernel(_AbstractReductionKernel):
def __call__(self, object a, axis=None, dtype=None, _ndarray_base out=None,
bint keepdims=False):

if hasattr(a, '__cupy_override_reduction_kernel__'):
return a.__cupy_override_reduction_kernel__(
self, axis, dtype, out, keepdims)

cdef _ndarray_base arr

if isinstance(a, _ndarray_base):
Expand Down
181 changes: 0 additions & 181 deletions cupyx/distributed/_array.py

This file was deleted.

57 changes: 29 additions & 28 deletions cupyx/distributed/_nccl_comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,18 @@
_nccl_ops = {}


def _get_nccl_dtype_and_count(array, count=None):
dtype = array.dtype.char
if dtype not in _nccl_dtypes:
raise TypeError(f'Unknown dtype {array.dtype} for NCCL')
nccl_dtype = _nccl_dtypes[dtype]
if count is None:
count = array.size
if dtype in 'FD':
return nccl_dtype, 2 * count
return nccl_dtype, count


class NCCLBackend(_Backend):
"""Interface that uses NVIDIA's NCCL to perform communications.
Expand Down Expand Up @@ -104,17 +116,6 @@ def _check_contiguous(self, array):
raise RuntimeError(
'NCCL requires arrays to be either c- or f-contiguous')

def _get_nccl_dtype_and_count(self, array, count=None):
dtype = array.dtype.char
if dtype not in _nccl_dtypes:
raise TypeError(f'Unknown dtype {array.dtype} for NCCL')
nccl_dtype = _nccl_dtypes[dtype]
if count is None:
count = array.size
if dtype in 'FD':
return nccl_dtype, 2 * count
return nccl_dtype, count

def _get_stream(self, stream):
if stream is None:
stream = cupy.cuda.stream.get_current_stream()
Expand Down Expand Up @@ -315,7 +316,7 @@ def all_reduce(cls, comm, in_array, out_array, op='sum', stream=None):
comm._check_contiguous(in_array)
comm._check_contiguous(out_array)
stream = comm._get_stream(stream)
dtype, count = comm._get_nccl_dtype_and_count(in_array)
dtype, count = _get_nccl_dtype_and_count(in_array)
op = comm._get_op(op, in_array.dtype.char)
comm._comm.allReduce(
in_array.data.ptr, out_array.data.ptr, count, dtype, op, stream)
Expand All @@ -326,7 +327,7 @@ def reduce(cls, comm, in_array, out_array, root=0, op='sum', stream=None):
if comm.rank == root:
comm._check_contiguous(out_array)
stream = comm._get_stream(stream)
dtype, count = comm._get_nccl_dtype_and_count(in_array)
dtype, count = _get_nccl_dtype_and_count(in_array)
op = comm._get_op(op, in_array.dtype.char)
comm._comm.reduce(
in_array.data.ptr, out_array.data.ptr,
Expand All @@ -336,7 +337,7 @@ def reduce(cls, comm, in_array, out_array, root=0, op='sum', stream=None):
def broadcast(cls, comm, in_out_array, root=0, stream=None):
comm._check_contiguous(in_out_array)
stream = comm._get_stream(stream)
dtype, count = comm._get_nccl_dtype_and_count(in_out_array)
dtype, count = _get_nccl_dtype_and_count(in_out_array)
comm._comm.broadcast(
in_out_array.data.ptr, in_out_array.data.ptr,
count, dtype, root, stream)
Expand All @@ -347,7 +348,7 @@ def reduce_scatter(
comm._check_contiguous(in_array)
comm._check_contiguous(out_array)
stream = comm._get_stream(stream)
dtype, count = comm._get_nccl_dtype_and_count(in_array, count)
dtype, count = _get_nccl_dtype_and_count(in_array, count)
op = comm._get_op(op, in_array.dtype.char)
comm._comm.reduceScatter(
in_array.data.ptr, out_array.data.ptr, count, dtype, op, stream)
Expand All @@ -357,15 +358,15 @@ def all_gather(cls, comm, in_array, out_array, count, stream=None):
comm._check_contiguous(in_array)
comm._check_contiguous(out_array)
stream = comm._get_stream(stream)
dtype, count = comm._get_nccl_dtype_and_count(in_array, count)
dtype, count = _get_nccl_dtype_and_count(in_array, count)
comm._comm.allGather(
in_array.data.ptr, out_array.data.ptr, count, dtype, stream)

@classmethod
def send(cls, comm, array, peer, stream=None):
comm._check_contiguous(array)
stream = comm._get_stream(stream)
dtype, count = comm._get_nccl_dtype_and_count(array)
dtype, count = _get_nccl_dtype_and_count(array)
cls._send(comm, array, peer, dtype, count, stream)

@classmethod
Expand All @@ -376,7 +377,7 @@ def _send(cls, comm, array, peer, dtype, count, stream=None):
def recv(cls, comm, out_array, peer, stream=None):
comm._check_contiguous(out_array)
stream = comm._get_stream(stream)
dtype, count = comm._get_nccl_dtype_and_count(out_array)
dtype, count = _get_nccl_dtype_and_count(out_array)
cls._recv(comm, out_array, peer, dtype, count, stream)

@classmethod
Expand All @@ -388,8 +389,8 @@ def send_recv(cls, comm, in_array, out_array, peer, stream=None):
comm._check_contiguous(in_array)
comm._check_contiguous(out_array)
stream = comm._get_stream(stream)
idtype, icount = comm._get_nccl_dtype_and_count(in_array)
odtype, ocount = comm._get_nccl_dtype_and_count(out_array)
idtype, icount = _get_nccl_dtype_and_count(in_array)
odtype, ocount = _get_nccl_dtype_and_count(out_array)
nccl.groupStart()
cls._send(comm, in_array, peer, idtype, icount, stream)
cls._recv(comm, out_array, peer, odtype, ocount, stream)
Expand All @@ -408,9 +409,9 @@ def scatter(cls, comm, in_array, out_array, root=0, stream=None):
if root == comm.rank:
for i in range(comm._n_devices):
array = in_array[i]
idtype, icount = comm._get_nccl_dtype_and_count(array)
idtype, icount = _get_nccl_dtype_and_count(array)
cls._send(comm, array, i, idtype, icount, stream)
dtype, count = comm._get_nccl_dtype_and_count(out_array)
dtype, count = _get_nccl_dtype_and_count(out_array)
cls._recv(comm, out_array, root, dtype, count, stream)
nccl.groupEnd()

Expand All @@ -428,9 +429,9 @@ def gather(cls, comm, in_array, out_array, root=0, stream=None):
if root == comm.rank:
for i in range(comm._n_devices):
array = out_array[i]
odtype, ocount = comm._get_nccl_dtype_and_count(array)
odtype, ocount = _get_nccl_dtype_and_count(array)
cls._recv(comm, array, i, odtype, ocount, stream)
dtype, count = comm._get_nccl_dtype_and_count(in_array)
dtype, count = _get_nccl_dtype_and_count(in_array)
cls._send(comm, in_array, root, dtype, count, stream)
nccl.groupEnd()

Expand All @@ -448,8 +449,8 @@ def all_to_all(cls, comm, in_array, out_array, stream=None):
comm._check_contiguous(in_array)
comm._check_contiguous(out_array)
stream = comm._get_stream(stream)
idtype, icount = comm._get_nccl_dtype_and_count(in_array[0])
odtype, ocount = comm._get_nccl_dtype_and_count(out_array[0])
idtype, icount = _get_nccl_dtype_and_count(in_array[0])
odtype, ocount = _get_nccl_dtype_and_count(out_array[0])
# TODO check out dtypes are the same as in dtypes
nccl.groupStart()
for i in range(comm._n_devices):
Expand Down Expand Up @@ -728,7 +729,7 @@ def _send(cls, comm, array, peer, dtype, count, stream=None):
dtype = array.dtype.char
if dtype not in _nccl_dtypes:
raise TypeError(f'Unknown dtype {array.dtype} for NCCL')
dtype, count = comm._get_nccl_dtype_and_count(array)
dtype, count = _get_nccl_dtype_and_count(array)
stream = comm._get_stream(stream)
comm._comm.send(array.data.ptr, count, dtype, peer, stream)

Expand Down Expand Up @@ -757,7 +758,7 @@ def _recv(cls, comm, out_array, peer, dtype, count, stream=None):
dtype = dtype.char
if dtype not in _nccl_dtypes:
raise TypeError(f'Unknown dtype {out_array.dtype} for NCCL')
dtype, count = comm._get_nccl_dtype_and_count(out_array)
dtype, count = _get_nccl_dtype_and_count(out_array)
stream = comm._get_stream(stream)
comm._comm.recv(out_array.data.ptr, count, dtype, peer, stream)

Expand Down
6 changes: 6 additions & 0 deletions cupyx/distributed/array/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from cupyx.distributed.array._array import DistributedArray # NOQA
from cupyx.distributed.array._array import distributed_array # NOQA
from cupyx.distributed.array._linalg import make_2d_index_map # NOQA
from cupyx.distributed.array._linalg import matmul # NOQA
from cupyx.distributed.array._modes import Mode # NOQA
from cupyx.distributed.array._modes import REPLICA, MIN, MAX, SUM, PROD # NOQA

0 comments on commit f3536c6

Please sign in to comment.