Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Distributed] add nccl primitives #280

Merged
merged 23 commits into from
Jun 19, 2023
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
64 changes: 64 additions & 0 deletions examples/distributed/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import hidet
import multiprocessing
from multiprocessing import Process
import hidet.cuda.nccl
from hidet.cuda.nccl import NcclUniqueId, create_comm, ncclDataType, ncclRedOp
from hidet.ffi.runtime_api import NCCLRuntimeAPI
from hidet.lang import attrs
from hidet.ir.primitives.cuda.nccl import all_reduce
from hidet.ir.type import data_type
from hidet.ir import Task
from hidet.graph.ops.utils import input_like, compute
from hidet.utils import prod
from hidet.drivers import build_task, build_ir_module
from hidet.libinfo import get_nccl_include_dirs, get_nccl_library_search_dirs
from hidet.runtime import load_compiled_module

print("NCCL version:", NCCLRuntimeAPI.get_version())

def run(world_size, rank, shared_id, barrier):
if rank == 0:
NCCLRuntimeAPI.get_unique_id(shared_id)
barrier.wait()
hidet.cuda.set_device(rank)

print('initialize', rank)
comm = create_comm(world_size, shared_id, rank)

device = f"cuda:{rank}"
send = hidet.randn([2, 2], device=device)
recv = hidet.empty([2, 2], device=device)

dtype = data_type('float32')
shape = [2, 2]
nbytes = dtype.nbytes * prod(shape)

with hidet.script_module() as script_module:
@hidet.script
def launch(send: dtype[shape], recv: dtype[shape]):
attrs.func_kind = 'public'
all_reduce(0, send, recv, nbytes, ncclDataType.float32, ncclRedOp.sum)

ir_module = script_module.ir_module()
ir_module.target = 'cuda'
ir_module.include_dirs.extend(get_nccl_include_dirs())
ir_module.linking_dirs.extend(get_nccl_library_search_dirs())
ir_module.include_headers.append(["nccl.h"])
ir_module.linking_libs.append(":libnccl.so.2")
out_dir = f'./.cache/all_reduce_{rank}'
build_ir_module(ir_module, out_dir, target='cuda')
compiled_module = load_compiled_module(out_dir)
compiled_module(send, recv)
s = hidet.cuda.current_stream()
s.synchronize()
print(recv)

world_size = 4
barrier = multiprocessing.Barrier(world_size)
shared_id = multiprocessing.Value(NcclUniqueId, lock=False)
processes = [Process(target=run, args=(world_size, i, shared_id, barrier)) for i in range(world_size)]

for p in processes:
p.start()
for p in processes:
p.join()
14 changes: 14 additions & 0 deletions include/hidet/runtime/cuda/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,15 @@
#include <hidet/runtime/context.h>
// #include <cuda_runtime.h>

#include <vector>

struct CudaContext: BaseContext {
/* The cuda stream the kernels will be launched on. */
void* stream = nullptr;

/* NCCL Comunicators*/
std::vector<void *> nccl_comms;

/**
* Get the instance of cuda context.
*/
Expand All @@ -40,3 +45,12 @@ DLL void* get_cuda_stream();
*/
DLL void* request_cuda_workspace(size_t nbytes, bool require_clean);

/**
* Add a NCCL communicator to the context.
*/
DLL void add_nccl_comm(void* comm);

/**
* Get the NCCL communicator by the index
*/
DLL void* get_nccl_comm(int idx);
71 changes: 58 additions & 13 deletions python/hidet/backend/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,15 @@ class SourceCompiler:
The base class of source compiler.
"""

def compile(self, src_path: str, out_lib_path: str, linking_objects: Sequence[str]) -> None:
def compile(
self,
src_path: str,
out_lib_path: str,
include_dirs: Sequence[str]=(),
linking_dirs: Sequence[str]=(),
linking_libraries: Sequence[str]=(),
object_files: Sequence[str]=()
) -> None:
raise NotImplementedError()

def run_compile_command(self, command: str, src_path, out_lib_path: str):
Expand Down Expand Up @@ -104,8 +112,16 @@ def _resolve_nvcc_path():
return path
raise FileNotFoundError('Can not find nvcc compiler.')

def compile(self, src_path: str, out_lib_path: str, linking_objects: Sequence[str]) -> None:
if len(linking_objects) > 0 and out_lib_path.endswith('.o'):
def compile(
self,
src_path: str,
out_lib_path: str,
include_dirs: Sequence[str] = (),
linking_dirs: Sequence[str] = (),
linking_libraries: Sequence[str] = (),
object_files: Sequence[str] = ()
) -> None:
if len(object_files) > 0 and out_lib_path.endswith('.o'):
raise ValueError('Can not compile multiple objects into a single object file.')

cc = hidet.cuda.compute_capability()
Expand All @@ -118,9 +134,10 @@ def compile(self, src_path: str, out_lib_path: str, linking_objects: Sequence[st
# the path to nvcc compiler
self.nvcc_path,
# the included directories.
*['-I{}'.format(include_dir) for include_dir in self.include_dirs],
*['-I{}'.format(include_dir) for include_dir in self.include_dirs + list(include_dirs)],
# the library directories.
*['-L{}'.format(library_dir) for library_dir in self.library_dirs],
*['-L{}'.format(library_dir) for library_dir in self.library_dirs + list(linking_dirs)],
*['-l{}'.format(library) for library in linking_libraries],
# optimize host side code via -O3
'-O3',
# host compiler options: enable openmp, avx2, unroll loops and fast math
Expand Down Expand Up @@ -153,7 +170,7 @@ def compile(self, src_path: str, out_lib_path: str, linking_objects: Sequence[st
# generate shared library (lib.so).
'--shared' if out_lib_path.endswith('.so') else '--compile',
# the linking objects.
' '.join(linking_objects),
' '.join(object_files),
# the source path.
src_path,
# the output library path.
Expand All @@ -179,16 +196,25 @@ def _resolve_gcc_path():
return path
raise FileNotFoundError('Can not find g++ compiler.')

def compile(self, src_path: str, out_lib_path: str, linking_objects: Sequence[str]) -> None:
if len(linking_objects) > 0 and out_lib_path.endswith('.o'):
def compile(
self,
src_path: str,
out_lib_path: str,
include_dirs: Sequence[str]=(),
linking_dirs: Sequence[str]=(),
linking_libraries: Sequence[str]=(),
object_files: Sequence[str]=()
) -> None:
if len(object_files) > 0 and out_lib_path.endswith('.o'):
raise ValueError('Can not compile multiple objects into a single object file.')
command = [
# the path to nvcc compiler
self.gcc_path,
# the included directories.
*['-I{}'.format(include_dir) for include_dir in self.include_dirs],
*['-I{}'.format(include_dir) for include_dir in self.include_dirs + list(include_dirs)],
# the library directories.
*['-L{}'.format(library_dir) for library_dir in self.library_dirs],
*['-L{}'.format(library_dir) for library_dir in self.library_dirs + list(linking_dirs)],
*['-l{}'.format(library) for library in linking_libraries],
# apply -O3 optimization.
'-O3',
# support avx intrinsics
Expand All @@ -204,7 +230,7 @@ def compile(self, src_path: str, out_lib_path: str, linking_objects: Sequence[st
# generate shared library (lib.so).
'-shared' if out_lib_path.endswith('.so') else '--compile',
# the linking objects.
' '.join(linking_objects),
' '.join(object_files),
# the source path.
src_path,
# the output library path.
Expand All @@ -216,7 +242,13 @@ def compile(self, src_path: str, out_lib_path: str, linking_objects: Sequence[st


def compile_source(
source_file: str, output_library_file: str, target: str, object_files: Optional[Sequence[str]]
source_file: str,
output_library_file: str,
target: str,
include_dirs: Sequence[str] = (),
linking_dirs: Sequence[str] = (),
linking_libraries: Sequence[str] = (),
object_files: Sequence[str] = (),
) -> None:
"""
Compile the source code in 'src_path' file and output the library to 'out_lib_path'.
Expand All @@ -229,6 +261,12 @@ def compile_source(
The path to output library.
target: str
The target platform. Currently only support 'cpu' and 'gpu'.
include_dirs: Optional[Sequence[str]]
The include directories.
linking_dirs: Optional[Sequence[str]]
The library directories.
linking_libraries:
The libraries to link to the output library.
object_files: Optional[Sequence[str]]
The path to object files. If not None, the object files will be linked to the output library.
"""
Expand All @@ -247,4 +285,11 @@ def compile_source(
raise ValueError('Unknown target platform: {}'.format(target))

object_files = object_files or []
compiler.compile(source_file, output_library_file, object_files)
compiler.compile(
source_file,
output_library_file,
include_dirs=include_dirs,
linking_dirs=linking_dirs,
linking_libraries=linking_libraries,
object_files=object_files
)
5 changes: 5 additions & 0 deletions python/hidet/backend/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,9 @@ def require_headers(self) -> Doc:
doc += Text('#include <hidet/runtime/cuda/complex.h>') + NewLine()
doc += Text('#include <hidet/runtime/cuda/context.h>') + NewLine()

for header in self.ir_module.include_headers:
doc += Text('#include <{}>').format(header) + NewLine()

if self.require_tf32:
# nvcc use float to 'store' tfloat32 data
doc += Text('typedef float tfloat32_t;') + NewLine()
Expand Down Expand Up @@ -762,6 +765,8 @@ def require_headers(self) -> Doc:
doc += Text('#include <hidet/runtime/cpu/float16.h>') + NewLine()
if self.require_bf16:
doc += Text('#include <hidet/runtime/cpu/bfloat16.h>') + NewLine()
for header in self.ir_module.include_headers:
doc += Text('#include <{}>').format(header) + NewLine()
if self.require_tf32:
doc += Text('typedef float tfloat32_t;') + NewLine()
doc += NewLine()
Expand Down
31 changes: 31 additions & 0 deletions python/hidet/cuda/nccl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from enum import IntEnum

from ..ffi import NcclUniqueId, NcclCommunicator, nccl_runtime_api

class ncclDataType(IntEnum):
int8 = 0
char = 0
uint8 = 1
int32 = 2
int = 2
uint32 = 3
int64 = 4
uint64 = 5
float16 = 6
half = 6
float32 = 7
float = 7
float64 = 8
double = 8
bfloat = 9

class ncclRedOp(IntEnum):
sum = 0
prod = 1
max = 2
min = 3
avg = 4

def create_comm(nranks: int, unique_id: NcclUniqueId, rank: int):
handle = nccl_runtime_api.comm_init_rank(nranks, unique_id, rank)
return NcclCommunicator(handle)
6 changes: 5 additions & 1 deletion python/hidet/drivers/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,11 @@ def build_ir_module(
codegen(ir_module, src_out_path=src_path, target=target)

# compile source code
compile_source(src_path, output_library_file=lib_path, target=target, object_files=object_files)
compile_source(src_path, output_library_file=lib_path, target=target,
include_dirs=ir_module.include_dirs,
linking_dirs = ir_module.linking_dirs,
linking_libraries = ir_module.linking_libs,
object_files=object_files)

# write the function types
if output_kind == '.so':
Expand Down
8 changes: 2 additions & 6 deletions python/hidet/drivers/build_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,6 @@ def get_output_shape(idx: int32, dims: ~int32):

launch_func.name = 'launch_0'
task_ir_module.functions['launch_0'] = launch_func

object_files = []
else:
# otherwise, build each candidate to a .o file, and link them into the task's ir module
for i, candidate in enumerate(candidates):
Expand Down Expand Up @@ -113,13 +111,11 @@ def launch(arg: meta.types(param_types)):
ir_module = script_module.ir_module()
ir_module.add_function(get_input_shape.name, get_input_shape)
ir_module.add_function(get_output_shape.name, get_output_shape)
ir_module.object_files.extend([os.path.join(task_dir, 'candidates', str(i), 'lib.o') for i in range(len(candidates))])
task_ir_module = ir_module
object_files = [os.path.join(task_dir, 'candidates', str(i), 'lib.o') for i in range(len(candidates))]

# build task ir module
build_ir_module(
ir_module=task_ir_module, output_dir=task_dir, output_kind='.so', object_files=object_files, target=target
)
build_ir_module(ir_module=task_ir_module, output_dir=task_dir, output_kind='.so', target=target)


def generate_meta_data(task: Task, task_dir: str, build_target: str, num_candidates: int):
Expand Down
6 changes: 4 additions & 2 deletions python/hidet/ffi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .ffi import _LIB
from .runtime_api import runtime_api
from .ffi import _LIB, _LIB_NCCL
from .runtime_api import runtime_api, nccl_available
if nccl_available():
from .runtime_api import nccl_runtime_api, NcclUniqueId, NcclCommunicator, get_nccl_comm

from . import callbacks
from . import crt
Expand Down
28 changes: 24 additions & 4 deletions python/hidet/ffi/ffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,16 @@
from typing import List, Dict, Optional
import os
import os.path
import glob
import ctypes
from hidet.libinfo import get_library_search_dirs
from hidet.libinfo import get_library_search_dirs, get_nccl_library_search_dirs

_LIB: Optional[ctypes.CDLL] = None
_LIB_RUNTIME: Optional[ctypes.CDLL] = None
_LIB_NCCL: Optional[ctypes.CDLL] = None


library_paths: Dict[str, Optional[str]] = {'hidet': None, 'hidet_runtime': None}
library_paths: Dict[str, Optional[str]] = {'hidet': None, 'hidet_runtime': None, 'nccl': None}


def load_library():
Expand All @@ -40,6 +42,18 @@ def load_library():
if _LIB is None:
raise OSError('Can not find library in the following directory: \n' + '\n'.join(library_dirs))

def load_nccl_library():
global _LIB_NCCL
library_dirs = get_nccl_library_search_dirs()
for library_dir in library_dirs:
lib_nccl_paths = glob.glob(os.path.join(library_dir, 'libnccl.so*'))
if len(lib_nccl_paths) == 0:
continue
_LIB_NCCL = ctypes.cdll.LoadLibrary(lib_nccl_paths[0])
library_paths['nccl'] = lib_nccl_paths[0]
break
if _LIB_NCCL is None:
raise OSError('Can not find nccl library in the following directory: \n' + '\n'.join(library_dirs))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's put this part in "hidet/cuda/nccl/ffi.py", and leave "hidet/ffi/ffi.py" to only contain the hidet runtime library.

In the future, when we want to add other library (e.g., cudnn library), we can put to "hidet/cuda/cudnn" and also its ffi.py.


def get_last_error() -> Optional[str]:
func = getattr(get_last_error, '_func', None)
Expand Down Expand Up @@ -71,10 +85,12 @@ def get_func(func_name, arg_types: List, restype):
func = getattr(_LIB, func_name)
elif func_exists(func_name, _LIB_RUNTIME):
func = getattr(_LIB_RUNTIME, func_name)
elif func_exists(func_name, _LIB_NCCL):
func = getattr(_LIB_NCCL, func_name)
else:
raise ValueError(
'Can not find function "{}" in hidet libraries:\n{}\n{}'.format(
func_name, library_paths['hidet'], library_paths['hidet_runtime']
'Can not find function "{}" in hidet libraries:\n{}\n{}\n{}'.format(
func_name, library_paths['hidet'], library_paths['hidet_runtime'], library_paths['nccl']
)
)

Expand All @@ -96,3 +112,7 @@ def func_with_check(*args):


load_library()
load_nccl_library()

def nccl_available():
return _LIB_NCCL is not None