[Distributed] all_reduce op and distributed info in graphs #284

Merged: 42 commits, merged on Jun 29, 2023

Changes from 17 commits

Commits (42)
343b71d
init
soodoshll Jun 19, 2023
29ef7f5
op
soodoshll Jun 19, 2023
8e06560
update
soodoshll Jun 19, 2023
c8559d1
graph
soodoshll Jun 19, 2023
d97a7f8
update
soodoshll Jun 19, 2023
70f3a91
format
soodoshll Jun 19, 2023
6819ab8
add distributed graph
soodoshll Jun 19, 2023
95cec0c
update
soodoshll Jun 21, 2023
961e99b
support split
soodoshll Jun 22, 2023
334a1eb
update
soodoshll Jun 22, 2023
7dd55c4
update
soodoshll Jun 22, 2023
0c57cff
relaunch test
soodoshll Jun 22, 2023
7b81b0d
Merge branch 'main' of github.com:hidet-org/hidet into nccl-op
soodoshll Jun 22, 2023
047ea87
update
soodoshll Jun 22, 2023
dba85a5
fix
soodoshll Jun 22, 2023
2c6e5b1
format
soodoshll Jun 22, 2023
5d51ed4
fix
soodoshll Jun 22, 2023
f4bf865
[Document] fix installation guide (#288)
soodoshll Jun 22, 2023
64b9f03
[Runtime] Check for input tensor device (#287)
hjjq Jun 22, 2023
57ae2a9
fix
soodoshll Jun 23, 2023
a3d0a71
fix
soodoshll Jun 23, 2023
2ffcfe3
fix
soodoshll Jun 23, 2023
ee60249
update
soodoshll Jun 23, 2023
f3aad89
[FixBug] Don't instantiate symbol for primitive functions (#291)
hjjq Jun 26, 2023
64a632a
file store
soodoshll Jun 27, 2023
c028827
file store
soodoshll Jun 27, 2023
f118fd9
Merge branch 'nccl-op' into fs-store
soodoshll Jun 27, 2023
56a96ca
update
soodoshll Jun 27, 2023
a39c199
update
soodoshll Jun 27, 2023
0a04b82
update
soodoshll Jun 27, 2023
eedaf84
add test
soodoshll Jun 27, 2023
37c8654
format & copyright
soodoshll Jun 27, 2023
3fd7491
update
soodoshll Jun 27, 2023
8bc856f
update
soodoshll Jun 27, 2023
bb4d6d1
format
soodoshll Jun 27, 2023
8518e9e
update
soodoshll Jun 27, 2023
dcb87aa
fix
soodoshll Jun 27, 2023
a2d8be6
format
soodoshll Jun 27, 2023
917d24f
fix
soodoshll Jun 28, 2023
fdf749f
fix
soodoshll Jun 28, 2023
816da19
remove redundant seek
soodoshll Jun 28, 2023
c3eee0d
fix
soodoshll Jun 29, 2023
90 changes: 44 additions & 46 deletions examples/distributed/test.py
@@ -1,30 +1,32 @@
"""
Testing script for the distributed components of hidet
To debug, set the environment variable NCCL_DEBUG=INFO

To install nccl, run

pip install nvidia-nccl-cu11==2.18.3

Or

pip install nvidia-nccl-cu12==2.18.3
"""
import hidet
import multiprocessing
from multiprocessing import Process
import numpy
import argparse

import hidet
import hidet.cuda.nccl
from hidet.cuda import nccl
from hidet.cuda.nccl import NcclUniqueId, NcclDataType, NcclRedOp, nccl_library_filename
from hidet.ffi import runtime_api
from hidet.lang import attrs
from hidet.ir.primitives.cuda.nccl import all_reduce
from hidet.ir.type import data_type
from hidet.utils import prod
from hidet.drivers import build_ir_module
from hidet.cuda.nccl.libinfo import get_nccl_include_dirs, get_nccl_library_search_dirs
from hidet.runtime import load_compiled_module
from hidet.cuda.nccl import NcclUniqueId

print("NCCL version:", nccl.nccl_version())

parser = argparse.ArgumentParser()
parser.add_argument("n_gpus", type=int)
parser.add_argument("reduce_op", choices=['sum', 'prod', 'max', 'min', 'avg'])
parser.add_argument("--group_size", type=int, default=0)
args = parser.parse_args()

def run(world_size, rank, shared_id, barrier):
@@ -37,47 +39,43 @@ def run(world_size, rank, shared_id, barrier):
barrier.wait()
hidet.cuda.set_device(rank)

print('initialize', rank)
# Create NcclCommunicator and set the cuda context
# this part should be moved into CompiledGraph in the future
comm = nccl.create_comm(world_size, shared_id, rank)
comms_array = nccl.comms_to_array([comm])
runtime_api.set_nccl_comms(comms_array)
use_group = args.group_size > 1
if use_group:
gs = args.group_size
gn = world_size // gs
assert world_size % gs == 0
groups = [list(range(i * gs, (i + 1) * gs)) for i in range(gn)]
else:
groups = []


# Initialize send and receive buffer
device = f"cuda:{rank}"
send = hidet.randn([2, 2], device=device)
recv = hidet.empty([2, 2], device=device)

print(rank, send)

dtype = data_type('float32')
shape = [2, 2]
nbytes = dtype.nbytes * prod(shape)

# Define IRModule
with hidet.script_module() as script_module:
@hidet.script
def launch(send: dtype[shape], recv: dtype[shape]):
attrs.func_kind = 'public'
all_reduce(0, send, recv, nbytes, dtype, getattr(NcclRedOp, args.reduce_op))

# Build
ir_module = script_module.ir_module()
ir_module.target = 'cuda'
ir_module.include_dirs.extend(get_nccl_include_dirs())
ir_module.linking_dirs.extend(get_nccl_library_search_dirs())
ir_module.include_headers.append(["nccl.h"])
ir_module.linking_libs.append(":" + nccl_library_filename())
out_dir = f'./.cache/all_reduce_{rank}'

build_ir_module(ir_module, out_dir, target='cuda')
compiled_module = load_compiled_module(out_dir)

compiled_module(send, recv)
x = hidet.randn([1, 3], device=device)
w = hidet.randn([3, 2], device=device)

# Create Computation Graph
x_symb = hidet.symbol_like(x)
w_symb = hidet.symbol_like(w)
y_local = hidet.ops.relu(x_symb @ w_symb)
y_sync = hidet.ops.all_reduce(y_local, args.reduce_op, comm_id=int(use_group))
graph = hidet.trace_from([y_local, y_sync], inputs=[x_symb, w_symb])
opt_graph = hidet.graph.optimize(graph)
opt_graph.set_dist_attrs(nrank=world_size, rank=rank, groups=groups)
compiled = opt_graph.build()

# test save and load
compiled_dir = f"./outs/graph_{rank}.zip"
compiled.save(compiled_dir)
compiled = hidet.runtime.load_compiled_graph(compiled_dir)

# Create Distributed Graph
compiled.init_dist(shared_id)

y_local, y_sync = compiled(x, w)

s = hidet.cuda.current_stream()
s.synchronize()
print(rank, recv)
print(f"process {rank}\nbefore allreduce:{y_local}\nafter allreduce:{y_sync}\n", end='')

world_size = args.n_gpus

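For context, the part of test.py that the diff collapses is the per-rank launcher: the script is invoked as `python examples/distributed/test.py <n_gpus> <reduce_op> [--group_size N]`, allocates one shared NcclUniqueId, and spawns one process per GPU. A minimal sketch of that launcher, assuming NcclUniqueId is a ctypes structure that can live in shared memory (the actual collapsed code may differ):

```python
# Sketch of the launcher hidden by the collapsed diff; run() is the function shown above.
import multiprocessing
from multiprocessing import Process

from hidet.cuda.nccl import NcclUniqueId


def launch(world_size: int):
    # One NcclUniqueId in shared memory, visible to every rank (assumes a ctypes structure).
    shared_id = multiprocessing.Value(NcclUniqueId, lock=False)
    # Rank 0 fills shared_id in-place (via init_unique_id) before barrier.wait() releases the others.
    barrier = multiprocessing.Barrier(world_size)
    processes = [Process(target=run, args=(world_size, rank, shared_id, barrier))
                 for rank in range(world_size)]
    for p in processes:
        p.start()
    for p in processes:
        p.join()


launch(args.n_gpus)
```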
15 changes: 13 additions & 2 deletions python/hidet/cuda/nccl/__init__.py
@@ -9,5 +9,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .comm import create_comm, NcclUniqueId, NcclDataType, NcclRedOp, comms_to_array, init_unique_id, dtype_to_nccl
from .ffi import nccl_version, nccl_library_filename
from .ffi import nccl_available, nccl_version, nccl_library_filename
from .comm import (
create_comm,
NcclUniqueId,
NcclDataType,
NcclRedOp,
comms_to_array,
init_unique_id,
dtype_to_nccl,
NcclCommunicator,
str_to_nccl_op,
NCCL_SPLIT_NOCOLOR,
)
27 changes: 23 additions & 4 deletions python/hidet/cuda/nccl/comm.py
@@ -15,7 +15,12 @@

from hidet.ffi.utils import Array
from hidet.ir.type import void_p, DataType
from .ffi import nccl_runtime_api, NcclUniqueId
from .ffi import nccl_available, NcclUniqueId

NCCL_SPLIT_NOCOLOR = -1

if nccl_available():
from .ffi import nccl_runtime_api


class NcclDataType(IntEnum):
@@ -44,13 +49,20 @@ class NcclRedOp(IntEnum):
avg = 4


def str_to_nccl_op(name: str) -> NcclRedOp:
if name not in ('sum', 'prod', 'max', 'min', 'avg'):
raise RuntimeError(f"'{name}' is not a supported reduce op")
return getattr(NcclRedOp, name)


class NcclCommunicator:
def __init__(self, handle: int):
"""
Users should not call this constructor directly, because there are two ways of creating
a new communicator: 1) using unique_id and rank; 2) using split.
"""

if not nccl_available():
raise RuntimeError("NCCL is not available")
self._handle = handle

def __del__(self):
@@ -60,11 +72,16 @@ def __del__(self):
def handle(self):
return self._handle

def split(self):
raise NotImplementedError()
def split(self, key, color):
new_handle = nccl_runtime_api.comm_split(self._handle, color, key)
if color == NCCL_SPLIT_NOCOLOR:
return None
return NcclCommunicator(new_handle)


def create_comm(nranks: int, unique_id: NcclUniqueId, rank: int) -> NcclCommunicator:
if not nccl_available():
raise RuntimeError("NCCL is not available")
handle = nccl_runtime_api.comm_init_rank(nranks, unique_id, rank)
return NcclCommunicator(handle)

@@ -77,6 +94,8 @@ def comms_to_array(comms: List[NcclCommunicator]) -> Array:


def init_unique_id(unqie_id: NcclUniqueId) -> None:
if not nccl_available():
raise RuntimeError("NCCL is not available")
nccl_runtime_api.get_unique_id(unqie_id)
Member:

Can we define init_unique_id(...) as

def create_unique_id() -> NcclUniqueId:
    ...

I feel the current API is not very intuitive.

Collaborator (Author):

The point here is that the NcclUniqueId needs to be shared by all processes, and the current solution is:

  1. Create a shared NcclUniqueId object;
  2. Launch multiple processes with the shared unique id object as one argument;
  3. Initialize the shared unique id object in process 0, which needs a reference to the shared object.

If we create the NcclUniqueId in process 0 after the processes have been launched, it's not easy to broadcast it (if there's an elegant way of broadcasting, please let me know).

A workaround is to 1) create the shared object; 2) launch the processes; 3) create a new unique id object; 4) copy its value back to the shared object.



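To make the discussion above concrete, a hypothetical create_unique_id() plus the copy-back workaround might look like this (the `internal` byte-array field is an assumption based on NCCL's ncclUniqueId layout, not something this PR defines):

```python
# Hypothetical alternative API sketched from the review discussion above.
from hidet.cuda.nccl import NcclUniqueId, init_unique_id


def create_unique_id() -> NcclUniqueId:
    # Proposed wrapper: allocate a fresh id and initialize it in one call.
    unique_id = NcclUniqueId()
    init_unique_id(unique_id)
    return unique_id


def fill_shared_id(shared_id) -> None:
    # Workaround from the thread: the workers were already launched with a
    # pre-allocated shared_id, so copy the freshly created value back into it
    # (assumes NcclUniqueId exposes its bytes as an 'internal' field).
    fresh = create_unique_id()
    shared_id.internal = fresh.internal
```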
98 changes: 54 additions & 44 deletions python/hidet/cuda/nccl/ffi.py
@@ -49,8 +49,6 @@ def load_nccl_library():
_LIB_NCCL = ctypes.cdll.LoadLibrary(lib_nccl_paths[0])
nccl_library_path = lib_nccl_paths[0]
break
if _LIB_NCCL is None:
raise OSError('Can not find nccl library in the following directory: \n' + '\n'.join(library_dirs))


load_nccl_library()
@@ -60,49 +58,61 @@ def nccl_library_filename():
return os.path.basename(nccl_library_path)


if not nccl_available():
raise RuntimeError("NCCL Library not found.")
if nccl_available():


class NCCLRuntimeAPI:
"""
Runtime APIs regarding NCCL
TODO: Exception handling
"""

_get_version = get_func('ncclGetVersion', [c_void_p], c_int, lib=_LIB_NCCL)
_get_unique_id = get_func('ncclGetUniqueId', [c_void_p], c_int, lib=_LIB_NCCL)
_comm_init_rank = get_func('ncclCommInitRank', [c_void_p, c_int, NcclUniqueId, c_int], c_int, lib=_LIB_NCCL)
_comm_destroy = get_func('ncclCommDestroy', [c_void_p], c_int, lib=_LIB_NCCL)

_comm_user_rank = get_func('ncclCommUserRank', [c_void_p, POINTER(c_int)], c_int, lib=_LIB_NCCL)
_comm_count = get_func('ncclCommCount', [c_void_p, POINTER(c_int)], c_int, lib=_LIB_NCCL)

@staticmethod
def get_version() -> int:
version = c_int(0)
NCCLRuntimeAPI._get_version(pointer(version))
return version.value

@staticmethod
def get_unique_id(comm_id: NcclUniqueId) -> None:
class NCCLRuntimeAPI:
"""
In-place initialization of the NcclUniqueId object
Runtime APIs regarding NCCL
TODO: Exception handling
"""
ret = NCCLRuntimeAPI._get_unique_id(pointer(comm_id))
assert ret == 0, ret

@staticmethod
def comm_init_rank(ndev: int, comm_id: NcclUniqueId, rank: int) -> int:
comm = c_void_p()
ret = NCCLRuntimeAPI._comm_init_rank(pointer(comm), ndev, comm_id, rank)
assert ret == 0, ret
return comm.value

@staticmethod
def comm_destroy(comm_handle) -> None:
ret = NCCLRuntimeAPI._comm_destroy(comm_handle)
assert ret == 0


nccl_runtime_api = NCCLRuntimeAPI()
_get_version = get_func('ncclGetVersion', [c_void_p], c_int, lib=_LIB_NCCL)
_get_unique_id = get_func('ncclGetUniqueId', [c_void_p], c_int, lib=_LIB_NCCL)
_comm_init_rank = get_func('ncclCommInitRank', [c_void_p, c_int, NcclUniqueId, c_int], c_int, lib=_LIB_NCCL)
_comm_destroy = get_func('ncclCommDestroy', [c_void_p], c_int, lib=_LIB_NCCL)

_comm_user_rank = get_func('ncclCommUserRank', [c_void_p, POINTER(c_int)], c_int, lib=_LIB_NCCL)
_comm_count = get_func('ncclCommCount', [c_void_p, POINTER(c_int)], c_int, lib=_LIB_NCCL)

# Early versions of NCCL do not have split
try:
_comm_split = get_func('ncclCommSplit', [c_void_p, c_int, c_int, c_void_p, c_void_p], c_int, lib=_LIB_NCCL)
except ValueError:
_comm_split = None

@staticmethod
def get_version() -> int:
version = c_int(0)
NCCLRuntimeAPI._get_version(pointer(version))
return version.value

@staticmethod
def get_unique_id(comm_id: NcclUniqueId) -> None:
"""
In-place initialization of the NcclUniqueId object
"""
ret = NCCLRuntimeAPI._get_unique_id(pointer(comm_id))
assert ret == 0, ret

@staticmethod
def comm_init_rank(ndev: int, comm_id: NcclUniqueId, rank: int) -> int:
comm = c_void_p()
ret = NCCLRuntimeAPI._comm_init_rank(pointer(comm), ndev, comm_id, rank)
assert ret == 0, ret
return comm.value

@staticmethod
def comm_destroy(comm_handle) -> None:
ret = NCCLRuntimeAPI._comm_destroy(comm_handle)
assert ret == 0

@staticmethod
def comm_split(comm_handle: int, color: int, key: int) -> int:
if NCCLRuntimeAPI._comm_split is None:
raise RuntimeError("split is not supported on this version of NCCL. Please install a newer version.")
comm = c_void_p()
ret = NCCLRuntimeAPI._comm_split(comm_handle, color, key, pointer(comm), None)
assert ret == 0
return comm.value

nccl_runtime_api = NCCLRuntimeAPI()
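As a usage sketch (not part of this diff), the comm_split binding is what allows a world communicator to be carved into the rank groups that the test script builds with --group_size; assuming the NcclCommunicator.split(key, color) method added in comm.py, the expected pattern is roughly:

```python
# Sketch: derive a per-group communicator from the world communicator.
# NCCL's split semantics: ranks sharing a color join the same new communicator,
# ordered by key; NCCL_SPLIT_NOCOLOR opts a rank out (split() then returns None).
from hidet.cuda.nccl import NCCL_SPLIT_NOCOLOR


def make_group_comm(world_comm, rank: int, group_size: int):
    if group_size <= 1:
        # No sub-groups requested: this rank does not join any new communicator.
        return world_comm.split(key=rank, color=NCCL_SPLIT_NOCOLOR)
    color = rank // group_size   # which group this rank falls into
    key = rank % group_size      # this rank's position inside the new communicator
    return world_comm.split(key=key, color=color)
```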
18 changes: 17 additions & 1 deletion python/hidet/drivers/build_graph.py
@@ -22,7 +22,13 @@
from hidet.graph.tensor import Tensor
from hidet.graph.flow_graph import FlowGraph
from hidet.runtime.compiled_module import CompiledModule
from hidet.runtime.compiled_graph import CompiledGraph, GraphMetaData, GraphExecution, GraphExecutionInstruction
from hidet.runtime.compiled_graph import (
CompiledGraph,
GraphMetaData,
GraphExecution,
GraphExecutionInstruction,
GraphDistributedInfo,
)
from hidet.runtime.compiled_task import CompiledTask, TensorSignature
from hidet.graph.operator import Operator
from hidet.ir import primitives
@@ -142,6 +148,12 @@ def get_graph_meta_data(graph: FlowGraph, num_kernels, space: int) -> GraphMetaData:
)


def get_graph_dist_info(graph: FlowGraph) -> GraphDistributedInfo:
if not graph.is_distributed():
return None
return GraphDistributedInfo(nrank=graph.nrank, rank=graph.rank, groups=graph.groups)


def build_graph_module(graph: FlowGraph, graph_weights: List[Tensor], node2kernel: List[int]) -> CompiledModule:
from hidet.lang import void_p, attrs, int32, int64, meta, cast
from hidet.ir.primitives.runtime import memory_planner_init, memory_planner_allocate, memory_planner_free
@@ -329,6 +341,9 @@ def build_flow_graph(graph, *, space=0) -> CompiledGraph:
# get the graph meta data
graph_meta_data = get_graph_meta_data(graph, len(graph_kernels), space)

# get distributed information
graph_dist_info = get_graph_dist_info(graph)

# build the compiled graph
compiled_graph = CompiledGraph(
meta=graph_meta_data,
Expand All @@ -337,6 +352,7 @@ def build_flow_graph(graph, *, space=0) -> CompiledGraph:
compiled_tasks=graph_kernels,
graph_execution=graph_execution,
graph_string=str(graph),
dist_info=graph_dist_info,
)

# save the compiled graph to cache
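For reference, GraphDistributedInfo is presumably just a record of the three attributes passed above; a minimal sketch of its assumed shape (the real definition lives in hidet.runtime.compiled_graph and may differ):

```python
# Assumed shape of GraphDistributedInfo; not taken from this diff.
from dataclasses import dataclass
from typing import List


@dataclass
class GraphDistributedInfo:
    nrank: int               # total number of ranks participating in the graph
    rank: int                # rank of the current process
    groups: List[List[int]]  # explicit rank groups; empty when no sub-groups are used

# build_flow_graph stores this (or None for a non-distributed graph) on the
# CompiledGraph, so compiled.init_dist(shared_id) can recreate communicators
# after the graph has been saved and reloaded.
```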