[Fix] NCCL API mismatch and NCCL primitive fix (#301)
* `count` in NCCL APIs refers to the number of elements, not bytes (see the sketch below);
* Fix the NCCL primitives: they can now handle dynamic shapes correctly.
soodoshll committed Jul 1, 2023
1 parent 790e775 commit 664f9f0
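
To make the first bullet concrete, here is a minimal sketch of the element-count vs. byte-count distinction, using NumPy purely as a stand-in for a hidet CUDA tensor (the array shape and dtype are made up for illustration):

    import numpy as np

    # Stand-in for the CUDA tensor that would be all-reduced.
    x = np.zeros((256, 4), dtype=np.float32)

    count = x.size     # 1024 elements: what NCCL's `count` parameter expects
    nbytes = x.nbytes  # 4096 bytes: not what NCCL expects

Before this commit, the byte count was passed where NCCL expects an element count, so each all-reduce asked NCCL to process `dtype.nbytes` times more elements than the buffers actually hold.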
Showing 3 changed files with 11 additions and 10 deletions.
python/hidet/distributed/group.py (2 changes: 1 addition & 1 deletion)
@@ -81,7 +81,7 @@ def all_reduce(self, tensor: Tensor, op: str):
         assert not tensor.is_symbolic()
         assert tensor.device.is_cuda()
         addr = tensor.storage.addr
-        self._comm.all_reduce(addr, addr, tensor.nbytes, tensor.dtype, op)
+        self._comm.all_reduce(addr, addr, tensor.size, tensor.dtype, op)
 
 
 def create_nccl_group(store: Store, world_size: int, rank: int):
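
For context, a hedged usage sketch of the group-level call fixed above; the rendezvous store, process layout, and tensor-creation call below are assumptions for illustration and not part of this diff:

    import hidet
    from hidet.distributed.group import create_nccl_group  # assumed import path for the file shown above

    def demo_all_reduce(store, world_size: int, rank: int):
        # `store` is whatever Store object the application's rendezvous produced;
        # it and the process layout are placeholders here.
        group = create_nccl_group(store, world_size, rank)
        x = hidet.ones([1024], dtype='float32', device='cuda')  # assumed tensor-creation call
        group.all_reduce(x, 'sum')  # 'sum' assumed to be a string accepted by str_to_nccl_op
        return x                    # after the fix, x.size (1024 elements) is handed to NCCL, not x.nbytes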
python/hidet/graph/ops/distributed.py (7 changes: 2 additions & 5 deletions)
@@ -28,24 +28,21 @@ def __init__(self, x: TensorNode, op: str, comm_id: int = 0):
 
         super().__init__('all_reduce', inputs=[x], outputs=[y], attributes={'comm_id': comm_id, 'op': op})
 
-    def __str__(self):
-        return "all_reduce"
-
     def implement(self, target: Union[Target, str], working_dir: str) -> List[IRModule]:
         import hidet
         from hidet.ir.primitives.cuda.nccl import all_reduce as _all_reduce
         from hidet.lang import attrs
 
         dtype: DataType = self.inputs[0].type.dtype
         shape: Tuple[Expr, ...] = self.inputs[0].shape
-        nbytes = dtype.nbytes * prod(shape)
+        size = prod(shape)
 
         with hidet.script_module() as script_module:
 
             @hidet.script
             def launch(x: dtype[shape], y: dtype[shape]):
                 attrs.func_kind = 'public'
-                _all_reduce(x, y, nbytes, dtype, str_to_nccl_op(self.op), self.comm_id)
+                _all_reduce(x, y, size, dtype, str_to_nccl_op(self.op), self.comm_id)
 
         return [script_module.ir_module()]
 
python/hidet/ir/primitives/cuda/nccl.py (12 changes: 8 additions & 4 deletions)
@@ -23,8 +23,12 @@ def all_reduce(sendbuff: Expr, recvbuff: Expr, count: Expr, dtype: DataType, op:
 
     comm = get_nccl_comm(comm_id)
     return BlackBoxStmt(
-        'ncclAllReduce({}, {}, {}, (ncclDataType_t){}, (ncclRedOp_t){}, '
-        '(ncclComm_t){}, (cudaStream_t){});'.format(
-            sendbuff, recvbuff, count, int(dtype_to_nccl(dtype)), int(op), comm, get_cuda_stream()
-        )
+        'ncclAllReduce({}, {}, {}, (ncclDataType_t){}, (ncclRedOp_t){}, ' '(ncclComm_t){}, (cudaStream_t){});',
+        sendbuff,
+        recvbuff,
+        count,
+        int(dtype_to_nccl(dtype)),
+        int(op),
+        comm,
+        get_cuda_stream(),
     )
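
This restructuring is the likely reason the primitive now handles dynamic shapes: `count` may be a symbolic hidet `Expr`, and baking it into the template with Python's `str.format` (the old code) stringifies the expression at trace time, whereas passing it as an extra `BlackBoxStmt` argument lets the code generator substitute the expression into the emitted call. A minimal sketch of the pattern, with a made-up kernel name and an assumed import path for `BlackBoxStmt`:

    from hidet.ir.expr import Expr
    from hidet.ir.stmt import BlackBoxStmt  # assumed import path

    def emit_my_kernel_call(count: Expr) -> BlackBoxStmt:
        # Old, buggy pattern: 'my_kernel({});'.format(count) fills the placeholder in Python.
        # New pattern: keep the placeholder in the template and pass `count` as an argument,
        # so it is filled in during code generation instead.
        return BlackBoxStmt('my_kernel({});', count)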
