Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Distributed] add nccl primitives #280

Merged
merged 23 commits into from
Jun 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
95 changes: 95 additions & 0 deletions examples/distributed/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""
Testing script for distributed components for hidet
To debug, set the environment variable NCCL_DEBUG=INFO
"""
import hidet
import multiprocessing
from multiprocessing import Process
import numpy
import argparse

import hidet.cuda.nccl
from hidet.cuda import nccl
from hidet.cuda.nccl import NcclUniqueId, NcclDataType, NcclRedOp, nccl_library_filename
from hidet.ffi import runtime_api
from hidet.lang import attrs
from hidet.ir.primitives.cuda.nccl import all_reduce
from hidet.ir.type import data_type
from hidet.utils import prod
from hidet.drivers import build_ir_module
from hidet.cuda.nccl.libinfo import get_nccl_include_dirs, get_nccl_library_search_dirs
from hidet.runtime import load_compiled_module

print("NCCL version:", nccl.nccl_version())

parser = argparse.ArgumentParser()
parser.add_argument("n_gpus", type=int)
parser.add_argument("reduce_op", choices=['sum', 'prod', 'max', 'min', 'avg'])
args = parser.parse_args()

def run(world_size, rank, shared_id, barrier):
numpy.random.seed(rank)

# Initialize unique id
if rank == 0:
nccl.init_unique_id(shared_id)

barrier.wait()
hidet.cuda.set_device(rank)

print('initialize', rank)
# Create NcclCommunicator and set the cuda context
# this part should be moved into CompiledGraph in the future
comm = nccl.create_comm(world_size, shared_id, rank)
comms_array = nccl.comms_to_array([comm])
runtime_api.set_nccl_comms(comms_array)

# Initialize send and receive buffer
device = f"cuda:{rank}"
send = hidet.randn([2, 2], device=device)
recv = hidet.empty([2, 2], device=device)

print(rank, send)

dtype = data_type('float32')
shape = [2, 2]
nbytes = dtype.nbytes * prod(shape)

# Define IRModule
with hidet.script_module() as script_module:
@hidet.script
def launch(send: dtype[shape], recv: dtype[shape]):
attrs.func_kind = 'public'
all_reduce(0, send, recv, nbytes, dtype, getattr(NcclRedOp, args.reduce_op))

# Build
ir_module = script_module.ir_module()
ir_module.target = 'cuda'
ir_module.include_dirs.extend(get_nccl_include_dirs())
ir_module.linking_dirs.extend(get_nccl_library_search_dirs())
ir_module.include_headers.append(["nccl.h"])
ir_module.linking_libs.append(":" + nccl_library_filename())
out_dir = f'./.cache/all_reduce_{rank}'

build_ir_module(ir_module, out_dir, target='cuda')
compiled_module = load_compiled_module(out_dir)

compiled_module(send, recv)
s = hidet.cuda.current_stream()
s.synchronize()
print(rank, recv)

world_size = args.n_gpus

# Barrier to ensure unique id is created
barrier = multiprocessing.Barrier(world_size)

# Create a unique id object in shared memory
shared_id = multiprocessing.Value(NcclUniqueId, lock=False)

processes = [Process(target=run, args=(world_size, i, shared_id, barrier)) for i in range(world_size)]

for p in processes:
p.start()
for p in processes:
p.join()
1 change: 0 additions & 1 deletion gallery/developer-guides/hidet-script-dynamic-kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ def matmul_kernel(
# iterate over the k tiles
num_k_tiles = (k_size + block_k_size - 1) // block_k_size
for k_tile in range(num_k_tiles):

# load smem_a [block_m_size, block_k_size] from global memory
for i, k in auto_map(block_m_size, block_k_size, workers=num_threads).on(
threadIdx.x
Expand Down
14 changes: 14 additions & 0 deletions include/hidet/runtime/cuda/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ struct CudaContext: BaseContext {
/* The cuda stream the kernels will be launched on. */
void* stream = nullptr;

/* NCCL Comunicators*/
void ** nccl_comms = nullptr;

int num_comms = 0;

/**
* Get the instance of cuda context.
*/
Expand All @@ -40,3 +45,12 @@ DLL void* get_cuda_stream();
*/
DLL void* request_cuda_workspace(size_t nbytes, bool require_clean);

/**
* Set required NCCL communicators of the context.
*/
DLL void set_nccl_comms(int num_comms, void** comm);

/**
* Get the NCCL communicator by the index
*/
DLL void* get_nccl_comm(int idx);
71 changes: 58 additions & 13 deletions python/hidet/backend/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,15 @@ class SourceCompiler:
The base class of source compiler.
"""

def compile(self, src_path: str, out_lib_path: str, linking_objects: Sequence[str]) -> None:
def compile(
self,
src_path: str,
out_lib_path: str,
include_dirs: Sequence[str] = (),
linking_dirs: Sequence[str] = (),
linking_libraries: Sequence[str] = (),
object_files: Sequence[str] = (),
) -> None:
raise NotImplementedError()

def run_compile_command(self, command: str, src_path, out_lib_path: str):
Expand Down Expand Up @@ -104,8 +112,16 @@ def _resolve_nvcc_path():
return path
raise FileNotFoundError('Can not find nvcc compiler.')

def compile(self, src_path: str, out_lib_path: str, linking_objects: Sequence[str]) -> None:
if len(linking_objects) > 0 and out_lib_path.endswith('.o'):
def compile(
self,
src_path: str,
out_lib_path: str,
include_dirs: Sequence[str] = (),
linking_dirs: Sequence[str] = (),
linking_libraries: Sequence[str] = (),
object_files: Sequence[str] = (),
) -> None:
if len(object_files) > 0 and out_lib_path.endswith('.o'):
raise ValueError('Can not compile multiple objects into a single object file.')

cc = hidet.cuda.compute_capability()
Expand All @@ -118,9 +134,10 @@ def compile(self, src_path: str, out_lib_path: str, linking_objects: Sequence[st
# the path to nvcc compiler
self.nvcc_path,
# the included directories.
*['-I{}'.format(include_dir) for include_dir in self.include_dirs],
*['-I{}'.format(include_dir) for include_dir in self.include_dirs + list(include_dirs)],
# the library directories.
*['-L{}'.format(library_dir) for library_dir in self.library_dirs],
*['-L{}'.format(library_dir) for library_dir in self.library_dirs + list(linking_dirs)],
*['-l{}'.format(library) for library in linking_libraries],
# optimize host side code via -O3
'-O3',
# host compiler options: enable openmp, avx2, unroll loops and fast math
Expand Down Expand Up @@ -153,7 +170,7 @@ def compile(self, src_path: str, out_lib_path: str, linking_objects: Sequence[st
# generate shared library (lib.so).
'--shared' if out_lib_path.endswith('.so') else '--compile',
# the linking objects.
' '.join(linking_objects),
' '.join(object_files),
# the source path.
src_path,
# the output library path.
Expand All @@ -179,16 +196,25 @@ def _resolve_gcc_path():
return path
raise FileNotFoundError('Can not find g++ compiler.')

def compile(self, src_path: str, out_lib_path: str, linking_objects: Sequence[str]) -> None:
if len(linking_objects) > 0 and out_lib_path.endswith('.o'):
def compile(
self,
src_path: str,
out_lib_path: str,
include_dirs: Sequence[str] = (),
linking_dirs: Sequence[str] = (),
linking_libraries: Sequence[str] = (),
object_files: Sequence[str] = (),
) -> None:
if len(object_files) > 0 and out_lib_path.endswith('.o'):
raise ValueError('Can not compile multiple objects into a single object file.')
command = [
# the path to nvcc compiler
self.gcc_path,
# the included directories.
*['-I{}'.format(include_dir) for include_dir in self.include_dirs],
*['-I{}'.format(include_dir) for include_dir in self.include_dirs + list(include_dirs)],
# the library directories.
*['-L{}'.format(library_dir) for library_dir in self.library_dirs],
*['-L{}'.format(library_dir) for library_dir in self.library_dirs + list(linking_dirs)],
*['-l{}'.format(library) for library in linking_libraries],
# apply -O3 optimization.
'-O3',
# support avx intrinsics
Expand All @@ -204,7 +230,7 @@ def compile(self, src_path: str, out_lib_path: str, linking_objects: Sequence[st
# generate shared library (lib.so).
'-shared' if out_lib_path.endswith('.so') else '--compile',
# the linking objects.
' '.join(linking_objects),
' '.join(object_files),
# the source path.
src_path,
# the output library path.
Expand All @@ -216,7 +242,13 @@ def compile(self, src_path: str, out_lib_path: str, linking_objects: Sequence[st


def compile_source(
source_file: str, output_library_file: str, target: str, object_files: Optional[Sequence[str]]
source_file: str,
output_library_file: str,
target: str,
include_dirs: Sequence[str] = (),
linking_dirs: Sequence[str] = (),
linking_libraries: Sequence[str] = (),
object_files: Sequence[str] = (),
) -> None:
"""
Compile the source code in 'src_path' file and output the library to 'out_lib_path'.
Expand All @@ -229,6 +261,12 @@ def compile_source(
The path to output library.
target: str
The target platform. Currently only support 'cpu' and 'gpu'.
include_dirs: Optional[Sequence[str]]
The include directories.
linking_dirs: Optional[Sequence[str]]
The library directories.
linking_libraries:
The libraries to link to the output library.
object_files: Optional[Sequence[str]]
The path to object files. If not None, the object files will be linked to the output library.
"""
Expand All @@ -247,4 +285,11 @@ def compile_source(
raise ValueError('Unknown target platform: {}'.format(target))

object_files = object_files or []
compiler.compile(source_file, output_library_file, object_files)
compiler.compile(
source_file,
output_library_file,
include_dirs=include_dirs,
linking_dirs=linking_dirs,
linking_libraries=linking_libraries,
object_files=object_files,
)
5 changes: 5 additions & 0 deletions python/hidet/backend/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,9 @@ def require_headers(self) -> Doc:
doc += Text('#include <hidet/runtime/cuda/context.h>') + NewLine()
doc += Text("#include <hidet/runtime/logging.h>") + NewLine()

for header in self.ir_module.include_headers:
doc += Text('#include <{}>').format(header) + NewLine()

if self.require_tf32:
# nvcc use float to 'store' tfloat32 data
doc += Text('typedef float tfloat32_t;') + NewLine()
Expand Down Expand Up @@ -768,6 +771,8 @@ def require_headers(self) -> Doc:
doc += Text('#include <hidet/runtime/cpu/float16.h>') + NewLine()
if self.require_bf16:
doc += Text('#include <hidet/runtime/cpu/bfloat16.h>') + NewLine()
for header in self.ir_module.include_headers:
doc += Text('#include <{}>').format(header) + NewLine()
if self.require_tf32:
doc += Text('typedef float tfloat32_t;') + NewLine()
doc += NewLine()
Expand Down
13 changes: 13 additions & 0 deletions python/hidet/cuda/nccl/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .comm import create_comm, NcclUniqueId, NcclDataType, NcclRedOp, comms_to_array, init_unique_id, dtype_to_nccl
from .ffi import nccl_version, nccl_library_filename