[Distributed] all_reduce op and distributed info in graphs #284

Merged
merged 42 commits on Jun 29, 2023
Changes from 38 commits
Commits
42 commits
343b71d
init
soodoshll Jun 19, 2023
29ef7f5
op
soodoshll Jun 19, 2023
8e06560
update
soodoshll Jun 19, 2023
c8559d1
graph
soodoshll Jun 19, 2023
d97a7f8
update
soodoshll Jun 19, 2023
70f3a91
format
soodoshll Jun 19, 2023
6819ab8
add distributed graph
soodoshll Jun 19, 2023
95cec0c
update
soodoshll Jun 21, 2023
961e99b
support split
soodoshll Jun 22, 2023
334a1eb
update
soodoshll Jun 22, 2023
7dd55c4
update
soodoshll Jun 22, 2023
0c57cff
relaunch test
soodoshll Jun 22, 2023
7b81b0d
Merge branch 'main' of github.com:hidet-org/hidet into nccl-op
soodoshll Jun 22, 2023
047ea87
update
soodoshll Jun 22, 2023
dba85a5
fix
soodoshll Jun 22, 2023
2c6e5b1
format
soodoshll Jun 22, 2023
5d51ed4
fix
soodoshll Jun 22, 2023
f4bf865
[Document] fix installation guide (#288)
soodoshll Jun 22, 2023
64b9f03
[Runtime] Check for input tensor device (#287)
hjjq Jun 22, 2023
57ae2a9
fix
soodoshll Jun 23, 2023
a3d0a71
fix
soodoshll Jun 23, 2023
2ffcfe3
fix
soodoshll Jun 23, 2023
ee60249
update
soodoshll Jun 23, 2023
f3aad89
[FixBug] Don't instantiate symbol for primitive functions (#291)
hjjq Jun 26, 2023
64a632a
file store
soodoshll Jun 27, 2023
c028827
file store
soodoshll Jun 27, 2023
f118fd9
Merge branch 'nccl-op' into fs-store
soodoshll Jun 27, 2023
56a96ca
update
soodoshll Jun 27, 2023
a39c199
update
soodoshll Jun 27, 2023
0a04b82
update
soodoshll Jun 27, 2023
eedaf84
add test
soodoshll Jun 27, 2023
37c8654
format & copyright
soodoshll Jun 27, 2023
3fd7491
update
soodoshll Jun 27, 2023
8bc856f
update
soodoshll Jun 27, 2023
bb4d6d1
format
soodoshll Jun 27, 2023
8518e9e
update
soodoshll Jun 27, 2023
dcb87aa
fix
soodoshll Jun 27, 2023
a2d8be6
format
soodoshll Jun 27, 2023
917d24f
fix
soodoshll Jun 28, 2023
fdf749f
fix
soodoshll Jun 28, 2023
816da19
remove redundant seek
soodoshll Jun 28, 2023
c3eee0d
fix
soodoshll Jun 29, 2023
16 changes: 5 additions & 11 deletions docs/source/getting-started/build-from-source.rst
@@ -2,7 +2,7 @@ Build from source
-------------------
.. _Build-from-source:

If you want to contribute to Hidet, or you encountered any problem installing hidet via pip, it is better to install
If you want to contribute to Hidet, or you encounter any problem installing hidet directly via pip, it is better to install
hidet from source.

Clone the code
@@ -32,21 +32,15 @@ shared library:

After building, you can find the two libraries ``libhidet.so`` and ``libhidet_runtime.so`` under the ``build/lib`` directory.

Update environment variables
Install the Hidet Python package
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

To allow Python interpreter to find hidet package under ``python`` directory of the repository, we should append the
directory to ``PYTHONPATH`` variable. To allow the system find the shared libraries we built in the previous step,
we should append ``build/lib`` directory to ``LD_LIBRARY_PATH`` variable.
Next, we will install the Hidet Python package in develop mode via pip:

.. code-block:: console

$ export HIDET_HOME=<The Path to Hidet Repo>
$ export PYTHONPATH=$PYTHONPATH:$HIDET_HOME/python
$ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HIDET_HOME/build/lib

To avoid repeating above commands, it is recommended to put above commands to your shell's initialization script
(e.g., ``~/.bashrc`` for Bash and ``~/.zshrc`` for Zsh).
$ cd .. # return to the root directory of Hidet
$ pip install -e .

Validation
~~~~~~~~~~
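A quick sanity check after the editable install (a sketch, not part of this PR; it assumes `hidet.cuda.available()` is exposed by the installed package):

import hidet

print(hidet.__version__)        # the package under python/ should now be importable
print(hidet.cuda.available())   # True when a CUDA device and the built libraries are found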
104 changes: 38 additions & 66 deletions examples/distributed/test.py
@@ -1,93 +1,65 @@
"""
Testing script for the distributed components of hidet
To debug, set the environment variable NCCL_DEBUG=INFO

To install nccl, run

pip install nvidia-nccl-cu11==2.18.3

Or

pip install nvidia-nccl-cu12==2.18.3
"""
import hidet
import multiprocessing
from multiprocessing import Process
import numpy
import argparse
import atexit
import os

import hidet
import hidet.cuda.nccl
from hidet.cuda import nccl
from hidet.cuda.nccl import NcclUniqueId, NcclDataType, NcclRedOp, nccl_library_filename
from hidet.ffi import runtime_api
from hidet.lang import attrs
from hidet.ir.primitives.cuda.nccl import all_reduce
from hidet.ir.type import data_type
from hidet.utils import prod
from hidet.drivers import build_ir_module
from hidet.cuda.nccl.libinfo import get_nccl_include_dirs, get_nccl_library_search_dirs
from hidet.runtime import load_compiled_module

print("NCCL version:", nccl.nccl_version())

parser = argparse.ArgumentParser()
parser.add_argument("n_gpus", type=int)
parser.add_argument("reduce_op", choices=['sum', 'prod', 'max', 'min', 'avg'])
args = parser.parse_args()

def run(world_size, rank, shared_id, barrier):
def run(world_size, rank):
numpy.random.seed(rank)

# Initialize unique id
if rank == 0:
nccl.init_unique_id(shared_id)

barrier.wait()
hidet.cuda.set_device(rank)
hidet.distributed.init_process_group(init_method='file://tmp', world_size=world_size, rank=rank)
hidet.distributed.set_nccl_comms()

print('initialize', rank)
# Create NcclCommunicator and set the cuda context
# this part should be moved into CompiledGraph in the future
comm = nccl.create_comm(world_size, shared_id, rank)
comms_array = nccl.comms_to_array([comm])
runtime_api.set_nccl_comms(comms_array)

# Initialize send and receive buffer
device = f"cuda:{rank}"
send = hidet.randn([2, 2], device=device)
recv = hidet.empty([2, 2], device=device)

print(rank, send)

dtype = data_type('float32')
shape = [2, 2]
nbytes = dtype.nbytes * prod(shape)

# Define IRModule
with hidet.script_module() as script_module:
@hidet.script
def launch(send: dtype[shape], recv: dtype[shape]):
attrs.func_kind = 'public'
all_reduce(0, send, recv, nbytes, dtype, getattr(NcclRedOp, args.reduce_op))

# Build
ir_module = script_module.ir_module()
ir_module.target = 'cuda'
ir_module.include_dirs.extend(get_nccl_include_dirs())
ir_module.linking_dirs.extend(get_nccl_library_search_dirs())
ir_module.include_headers.append(["nccl.h"])
ir_module.linking_libs.append(":" + nccl_library_filename())
out_dir = f'./.cache/all_reduce_{rank}'

build_ir_module(ir_module, out_dir, target='cuda')
compiled_module = load_compiled_module(out_dir)

compiled_module(send, recv)
s = hidet.cuda.current_stream()
s.synchronize()
print(rank, recv)
x = hidet.randn([1, 3], device=device)
w = hidet.randn([3, 2], device=device)

# test runtime distributed op
hidet.distributed.all_reduce(w, 'avg')
print(w)

# Create Computation Graph
x_symb = hidet.symbol_like(x)
w_symb = hidet.symbol_like(w)
y_local = hidet.ops.relu(x_symb @ w_symb)
y_sync = hidet.ops.all_reduce(y_local, args.reduce_op)
graph = hidet.trace_from([y_local, y_sync], inputs=[x_symb, w_symb])
opt_graph = hidet.graph.optimize(graph)
compiled = opt_graph.build()
y_local, y_sync = compiled(x, w)

hidet.cuda.current_stream().synchronize()
print(f"process {rank}\nbefore allreduce:{y_local}\nafter allreduce:{y_sync}\n", end='')
atexit._run_exitfuncs()

if os.path.exists('tmp'):
os.remove('tmp')

world_size = args.n_gpus

# Barrier to ensure unique id is created
barrier = multiprocessing.Barrier(world_size)

# Create a unique id object in shared memory
shared_id = multiprocessing.Value(NcclUniqueId, lock=False)

processes = [Process(target=run, args=(world_size, i, shared_id, barrier)) for i in range(world_size)]
processes = [Process(target=run, args=(world_size, i)) for i in range(world_size)]

for p in processes:
p.start()
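The script above exercises both the eager collective (`hidet.distributed.all_reduce`) and the graph-level op (`hidet.ops.all_reduce`). A condensed per-rank sketch of that flow, distilled from the test (it assumes NCCL is installed and each rank owns one visible GPU; the `worker` function and its launch are illustrative, not part of the PR):

import hidet
import hidet.distributed

def worker(world_size: int, rank: int):
    # One process per rank; each rank drives one GPU.
    hidet.cuda.set_device(rank)
    hidet.distributed.init_process_group(init_method='file://tmp', world_size=world_size, rank=rank)
    hidet.distributed.set_nccl_comms()

    device = f'cuda:{rank}'
    w = hidet.randn([3, 2], device=device)
    hidet.distributed.all_reduce(w, 'avg')  # eager/runtime collective on a tensor

    x = hidet.randn([1, 3], device=device)
    x_symb = hidet.symbol_like(x)
    y = hidet.ops.all_reduce(hidet.ops.relu(x_symb @ w), 'sum')  # collective recorded in the graph
    graph = hidet.graph.optimize(hidet.trace_from(y, inputs=[x_symb]))
    y_out = graph.build()(x)
    hidet.cuda.current_stream().synchronize()

As in the test, one such worker is launched per rank in its own process (for example with multiprocessing.Process).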
1 change: 1 addition & 0 deletions python/hidet/__init__.py
@@ -21,6 +21,7 @@
from . import drivers
from . import logging
from . import cuda
from . import distributed

from .version import __version__

15 changes: 13 additions & 2 deletions python/hidet/cuda/nccl/__init__.py
@@ -9,5 +9,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .comm import create_comm, NcclUniqueId, NcclDataType, NcclRedOp, comms_to_array, init_unique_id, dtype_to_nccl
from .ffi import nccl_version, nccl_library_filename
from .ffi import nccl_available, nccl_version, nccl_library_filename
from .comm import (
create_comm,
NcclUniqueId,
NcclDataType,
NcclRedOp,
comms_to_array,
create_unique_id,
dtype_to_nccl,
NcclCommunicator,
str_to_nccl_op,
NCCL_SPLIT_NOCOLOR,
)
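With nccl_available now exported, callers can guard NCCL-specific paths instead of failing at import time. A minimal sketch (assuming only that the NCCL shared library may or may not be installed):

from hidet.cuda.nccl import nccl_available, nccl_version

if nccl_available():
    print("NCCL version:", nccl_version())
else:
    print("NCCL not found; distributed ops are unavailable")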
45 changes: 38 additions & 7 deletions python/hidet/cuda/nccl/comm.py
@@ -10,12 +10,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import IntEnum
from typing import List
from typing import List, Optional
import struct

from hidet.ffi.utils import Array
from hidet.ir.type import void_p, DataType
from .ffi import nccl_runtime_api, NcclUniqueId
from hidet.cuda import Stream, current_stream
from .ffi import nccl_available, NcclUniqueId

NCCL_SPLIT_NOCOLOR = -1

if nccl_available():
from .ffi import nccl_runtime_api


class NcclDataType(IntEnum):
@@ -44,13 +50,20 @@ class NcclRedOp(IntEnum):
avg = 4


def str_to_nccl_op(name: str) -> NcclRedOp:
if name not in ('sum', 'prod', 'max', 'min', 'avg'):
raise RuntimeError(f"'{name}' is not a supported reduce op")
return getattr(NcclRedOp, name)


class NcclCommunicator:
def __init__(self, handle: int):
"""
Users should not call this constructor directly; instead, create a new communicator either
1) from a unique_id and rank, or 2) by splitting an existing communicator.
"""

if not nccl_available():
raise RuntimeError("NCCL is not available")
self._handle = handle

def __del__(self):
@@ -60,11 +73,25 @@ def __del__(self):
def handle(self):
return self._handle

def split(self):
raise NotImplementedError()
def split(self, key, color):
new_handle = nccl_runtime_api.comm_split(self._handle, color, key)
if color == NCCL_SPLIT_NOCOLOR:
return None
return NcclCommunicator(new_handle)

def all_reduce(
self, sendbuff: int, recvbuff: int, count: int, datatype: DataType, op: str, s: Optional[Stream] = None
):
if s is None:
s = current_stream()
nccl_runtime_api.all_reduce(
sendbuff, recvbuff, count, int(dtype_to_nccl(datatype)), int(str_to_nccl_op(op)), self._handle, s
)


def create_comm(nranks: int, unique_id: NcclUniqueId, rank: int) -> NcclCommunicator:
if not nccl_available():
raise RuntimeError("NCCL is not available")
handle = nccl_runtime_api.comm_init_rank(nranks, unique_id, rank)
return NcclCommunicator(handle)

@@ -76,8 +103,12 @@ def comms_to_array(comms: List[NcclCommunicator]) -> Array:
return array


def init_unique_id(unqie_id: NcclUniqueId) -> None:
nccl_runtime_api.get_unique_id(unqie_id)
def create_unique_id() -> NcclUniqueId:
if not nccl_available():
raise RuntimeError("NCCL is not available")
unique_id = NcclUniqueId()
nccl_runtime_api.get_unique_id(unique_id)
return unique_id


def dtype_to_nccl(dtype: DataType) -> NcclDataType:
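The NcclCommunicator docstring above points at two ways to obtain a communicator. A minimal sketch of both paths, using only names exported from hidet.cuda.nccl in this diff (it assumes NCCL is installed and this process drives a single GPU, so nranks=1 and rank=0):

from hidet.cuda.nccl import create_unique_id, create_comm, str_to_nccl_op, NCCL_SPLIT_NOCOLOR

# Path 1: one rank creates a unique id and shares it with the others
# (e.g. through the new file store); every rank then builds its communicator.
unique_id = create_unique_id()
comm = create_comm(1, unique_id, 0)

# Path 2: derive a new communicator from an existing one. Ranks that pass
# NCCL_SPLIT_NOCOLOR as the color are excluded and receive None.
sub_comm = comm.split(key=0, color=0)

# Reduce-op names are validated and mapped to NcclRedOp values.
op = str_to_nccl_op('avg')

NcclCommunicator.all_reduce itself takes raw device pointers and an element count; typical user code goes through hidet.distributed.all_reduce, which operates on tensors as in the example script above.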