[Distributed] add nccl primitives #280
Conversation
Hi @soodoshll, thanks!
I left some comments on how to organize the NCCL-related source code. I prefer putting it under a `hidet.cuda.nccl` submodule. This will keep our code structure cleaner when we add other vendor libraries like cuBLAS and cuDNN.
python/hidet/ffi/ffi.py
Outdated
```python
def load_nccl_library():
    global _LIB_NCCL
    library_dirs = get_nccl_library_search_dirs()
    for library_dir in library_dirs:
        lib_nccl_paths = glob.glob(os.path.join(library_dir, 'libnccl.so*'))
        if len(lib_nccl_paths) == 0:
            continue
        _LIB_NCCL = ctypes.cdll.LoadLibrary(lib_nccl_paths[0])
        library_paths['nccl'] = lib_nccl_paths[0]
        break
    if _LIB_NCCL is None:
        raise OSError('Cannot find the nccl library in the following directories:\n' + '\n'.join(library_dirs))
```
Let's put this part in `hidet/cuda/nccl/ffi.py`, and leave `hidet/ffi/ffi.py` to contain only the hidet runtime library.
In the future, when we want to add another library (e.g., cuDNN), we can put it under `hidet/cuda/cudnn` along with its own `ffi.py`.
python/hidet/ffi/runtime_api.py
Outdated
```python
class NcclUniqueId(Structure):
    """
    Defined in nccl.h
    """
    _fields_ = [("internal", c_byte * 128)]


class NcclCommunicator:
    def __init__(self, handle: int):
        """
        Users should not call this constructor directly, because there are two ways of creating
        a new communicator: 1) using unique_id and rank; 2) using split.
        """
        if not nccl_available():
            raise RuntimeError("NCCL library not found.")
        self._handle = handle

        # TODO: how to ensure the following two are identical?
        _comms.append(self)
        runtime_api.add_nccl_comm(self)

    def __del__(self):
        """
        Should we manage the lifetime of the communicator object in Python or C++?
        """
        nccl_runtime_api.comm_destroy(self)

    def split(self):
        raise NotImplementedError()
```
Consider moving this to `hidet/cuda/nccl.py`.
@yaoyaoding I have some questions related to code organization:

- How should we manage the lifetime of `NcclCommunicator` objects in Python and `ncclComm_t` objects in C++? There are two choices:
  1. In Python, as it is now: the resource is released when the Python object is released. This means that if we need a Python API `get_nccl_comm(idx)`, we have to maintain a global list `_comms` in Python holding all communicators and make sure it stays aligned with the `ncclComm_t` list in C++. However, it is likely that we don't need this API in Python; `NcclCommunicator.get_id()` is more important.
  2. In C++: we don't manage the lifetime of `ncclComm_t` at all, and only release them at exit. `NcclCommunicator` would then only be a thin wrapper over `ncclComm_t`.

  IMO, we should follow (1) and abandon `get_nccl_comm(idx)` in Python. No global list `_comms` is needed; if users think it's necessary, they can create one themselves.
- Should we import 'nccl.py' in 'nccl/ffi.py' or the opposite? `NcclCommunicator` might depend on some runtime APIs, so we would need to import 'nccl/ffi.py' in 'nccl.py', which might cause a circular import.
Good questions!
> How should we manage the lifetime of NcclCommunicator objects in Python and ncclComm_t objects in C++?

Different flow graphs may have different sets of communicators, so we should create the communicators for each flow graph. Let's keep the communicators in `CompiledGraph`. Something like:
```python
class CompiledGraph:
    def __init__(...):
        ...
        # create the communicators; store them as
        # hidet.ffi.utils.Array(void*, num of comm), and create a runtime api called
        # "void set_nccl_comms(int num_comm, void** comm_array)"
        self.nccl_comms = ...

    def run_async(...):
        ...
        runtime_api.set_nccl_comms(len(self.nccl_comms), self.nccl_comms)
        ...

    def __del__(...):
        # destroy communicators
        ...
```
For `FlowGraph.forward(...)`, we can raise an error for now if we find it is a distributed flow graph.
> Should we import 'nccl.py' in 'nccl/ffi.py' or the opposite? Because NcclCommunicator might depend on some runtime api, therefore we need 'nccl/ffi.py' in 'nccl.py' which might cause circular import.
Consider: let the nccl runtime API return an integer (the pointer to the communicator), and have `hidet/cuda/nccl/nccl.py` import `hidet/cuda/nccl/ffi.py`. The nccl ffi should only be used by `nccl.py`, and we expose the API to users in `nccl.py`.
Thanks for your reply! I agree with the idea of attaching communicators to compiled graphs.
Maybe we can leave the modification of `CompiledGraph` for another PR (which might include upper-level structures like ops and graphs)? Let's focus this PR on primitives; otherwise it will be huge and hard to test.
python/hidet/ffi/runtime_api.py
Outdated
```python
if nccl_available():
    class NCCLRuntimeAPI:
        """
        Runtime APIs regarding NCCL
        TODO: Exception handling
        """
        _get_version = get_func('ncclGetVersion', [c_void_p], c_int)
        _get_unique_id = get_func('ncclGetUniqueId', [c_void_p], c_int)
        _comm_init_rank = get_func('ncclCommInitRank', [c_void_p, c_int, NcclUniqueId, c_int], c_int)
        _comm_destroy = get_func('ncclCommDestroy', [c_void_p], c_int)

        _comm_user_rank = get_func('ncclCommUserRank', [c_void_p, POINTER(c_int)], c_int)
        _comm_count = get_func('ncclCommCount', [c_void_p, POINTER(c_int)], c_int)

        @staticmethod
        def get_version() -> int:
            version = c_int(0)
            NCCLRuntimeAPI._get_version(pointer(version))
            return version.value

        @staticmethod
        def get_unique_id(comm_id: NcclUniqueId) -> None:
            """
            In-place initialization of the NcclUniqueId object
            """
            ret = NCCLRuntimeAPI._get_unique_id(pointer(comm_id))
            assert ret == 0, ret

        @staticmethod
        def comm_init_rank(ndev: int, comm_id: NcclUniqueId, rank: int) -> int:
            comm = c_void_p()
            ret = NCCLRuntimeAPI._comm_init_rank(pointer(comm), ndev, comm_id, rank)
            assert ret == 0, ret
            return comm.value

        @staticmethod
        def comm_destroy(comm: NcclCommunicator) -> None:
            ret = NCCLRuntimeAPI._comm_destroy(comm._handle)
            assert ret == 0

    nccl_runtime_api = NCCLRuntimeAPI()
    _comms: List[NcclCommunicator] = []

    def get_nccl_comm(comm_id: int):
        return _comms[comm_id]
```
Also move this to `hidet.cuda.nccl`.
python/hidet/libinfo.py
Outdated
```python
def _get_nccl_dirs():
    import site
    return [os.path.join(root, 'nvidia', 'nccl') for root in site.getsitepackages()]


def get_nccl_include_dirs():
    return [os.path.join(root, 'include') for root in _get_nccl_dirs()]


def get_nccl_library_search_dirs():
    return [os.path.join(root, 'lib') for root in _get_nccl_dirs()]
```
Move to `hidet.cuda.nccl`.
Thanks @soodoshll!
Overall LGTM. I left some minor suggestions. We can merge this first and then work on other primitives and graph-level operators.
include/hidet/runtime/cuda/context.h
Outdated
```cpp
/**
 * Add NCCL communicators to the context.
 */
DLL void set_nccl_comms(void** comm, int num_comms);
```
```diff
- DLL void set_nccl_comms(void** comm, int num_comms);
+ DLL void set_nccl_comms(int num_comms, void** comm);
```
Let's keep the parameter order consistent with other APIs in hidet.
```python
from hidet.cuda.nccl import NcclDataType, NcclRedOp


def all_reduce(comm_id: int, sendbuff: Expr, recvbuff: Expr, count: Expr, dtype: NcclDataType, op: NcclRedOp):
```
```diff
- def all_reduce(comm_id: int, sendbuff: Expr, recvbuff: Expr, count: Expr, dtype: NcclDataType, op: NcclRedOp):
+ def all_reduce(comm_id: int, sendbuff: Expr, recvbuff: Expr, count: Expr, dtype: DataType, op: NcclRedOp):
```
Let's pass `hidet.ir.type.DataType` to this primitive and convert it into `NcclDataType` inside the primitive function.
```cpp
}

DLL void* get_nccl_comm(int idx) {
    return CudaContext::global()->nccl_comms[idx];
```
Better to add some check logic here to make sure `idx < num_comms`.
@yaoyaoding fixed. Please take a look :)
What we have now:
- `nvidia-nccl-cu11` or `nvidia-nccl-cu12`
- (`./examples/distributed/test.py`)