From 5cfc2fbf543c0cdc620ab21a1037b386c86f3899 Mon Sep 17 00:00:00 2001 From: SolenoidWGT <877825076@qq.com> Date: Mon, 15 Aug 2022 13:10:09 +0000 Subject: [PATCH] feature(wgt): enable DI using torch-rpc to support GPU-p2p and RDMA-rpc 1. Add torchrpc message queue. 2. Implement buffer based on CUDA-shared-tensor to optimize the data path of torchrpc. 3. Add 'bypass_eventloop' arg in Task() and Parallel(). 4. Add thread lock in distributer.py to prevent sender and receiver competition. 5. Add message queue perf test for torchrpc, nccl, nng, shm 6. Add comm_perf_helper.py to make program timing more convenient. 7. Modified the subscribe() of class MQ, adding 'fn' parameter and 'is_once' parameter. 8. Add new DummyLock and ConditionLock type in lock_helper.py 9. Add message queues perf test. 10. Introduced a new self-hosted runner to execute cuda, multiprocess, torchrpc related tests. --- .github/workflows/unit_test.yml | 83 +++- Makefile | 16 +- codecov.yml | 7 + ding/compatibility.py | 4 + ding/data/shm_buffer.py | 134 +++++- ding/data/tests/test_shm_buffer.py | 78 +++- ding/entry/cli_ditask.py | 47 ++- .../env_manager/subprocess_env_manager.py | 35 +- ding/framework/__init__.py | 4 +- ding/framework/message_queue/README.md | 13 + ding/framework/message_queue/__init__.py | 1 + ding/framework/message_queue/mq.py | 7 +- ding/framework/message_queue/nng.py | 2 +- .../framework/message_queue/perfs/perf_nng.py | 274 ++++++++++++ .../framework/message_queue/perfs/perf_shm.py | 141 +++++++ .../message_queue/perfs/perf_torchrpc_nccl.py | 278 +++++++++++++ .../perfs/tests/test_perf_nng.py | 14 + .../perfs/tests/test_perf_shm.py | 20 + .../perfs/tests/test_perf_torchrpc_nccl.py | 18 + ding/framework/message_queue/redis.py | 4 +- .../message_queue/tests/test_torch_rpc.py | 227 ++++++++++ ding/framework/message_queue/torch_rpc.py | 391 ++++++++++++++++++ ding/framework/middleware/distributer.py | 80 ++-- .../middleware/functional/collector.py | 23 +- ding/framework/parallel.py | 258 +++++++++++- ding/framework/task.py | 69 +++- ding/torch_utils/data_helper.py | 2 + ding/utils/__init__.py | 5 +- ding/utils/comm_perf_helper.py | 145 +++++++ ding/utils/lock_helper.py | 39 ++ dizoo/atari/example/atari_dqn_dist_ddp.py | 1 - dizoo/atari/example/atari_dqn_dist_rdma.py | 51 ++- pytest.ini | 2 + 33 files changed, 2350 insertions(+), 123 deletions(-) create mode 100644 ding/framework/message_queue/README.md create mode 100644 ding/framework/message_queue/perfs/perf_nng.py create mode 100644 ding/framework/message_queue/perfs/perf_shm.py create mode 100644 ding/framework/message_queue/perfs/perf_torchrpc_nccl.py create mode 100644 ding/framework/message_queue/perfs/tests/test_perf_nng.py create mode 100644 ding/framework/message_queue/perfs/tests/test_perf_shm.py create mode 100644 ding/framework/message_queue/perfs/tests/test_perf_torchrpc_nccl.py create mode 100644 ding/framework/message_queue/tests/test_torch_rpc.py create mode 100644 ding/framework/message_queue/torch_rpc.py create mode 100644 ding/utils/comm_perf_helper.py diff --git a/.github/workflows/unit_test.yml b/.github/workflows/unit_test.yml index c7195d820b..c69e5fe0e6 100644 --- a/.github/workflows/unit_test.yml +++ b/.github/workflows/unit_test.yml @@ -11,12 +11,11 @@ jobs: if: "!contains(github.event.head_commit.message, 'ci skip')" strategy: matrix: - python-version: [3.7, 3.8, 3.9] - + python-version: ["3.7", "3.8", "3.9"] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} - 
uses: actions/setup-python@v2 + uses: actions/setup-python@v3 with: python-version: ${{ matrix.python-version }} - name: do_unittest @@ -41,12 +40,13 @@ jobs: if: "!contains(github.event.head_commit.message, 'ci skip')" strategy: matrix: - python-version: [3.7, 3.8, 3.9] - + python-version: ["3.7", "3.8", "3.9"] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v3 + env: + AGENT_TOOLSDIRECTORY: /opt/hostedtoolcache with: python-version: ${{ matrix.python-version }} - name: do_benchmark @@ -55,3 +55,70 @@ jobs: python -m pip install ".[test,k8s]" ./ding/scripts/install-k8s-tools.sh make benchmark + + test_multiprocess: + runs-on: self-hosted + if: "!contains(github.event.head_commit.message, 'ci skip')" + strategy: + matrix: + python-version: ["3.7", "3.8", "3.9"] + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + - name: do_multiprocesstest + timeout-minutes: 40 + run: | + python -m pip install box2d-py + python -m pip install . + python -m pip install ".[test,k8s]" + ./ding/scripts/install-k8s-tools.sh + make multiprocesstest + + test_cuda: + runs-on: self-hosted + if: "!contains(github.event.head_commit.message, 'ci skip')" + strategy: + matrix: + python-version: ["3.7", "3.8", "3.9"] + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + env: + AGENT_TOOLSDIRECTORY: /opt/hostedtoolcache + with: + python-version: ${{ matrix.python-version }} + - name: do_unittest + timeout-minutes: 40 + run: | + python -m pip install torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113 + python -m pip install box2d-py + python -m pip install . + python -m pip install ".[test,k8s]" + ./ding/scripts/install-k8s-tools.sh + make cudatest + + test_mq_benchmark: + runs-on: self-hosted + if: "!contains(github.event.head_commit.message, 'ci skip')" + strategy: + matrix: + python-version: ["3.7", "3.8", "3.9"] + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + env: + AGENT_TOOLSDIRECTORY: /opt/hostedtoolcache + with: + python-version: ${{ matrix.python-version }} + - name: do_mqbenchmark + run: | + python -m pip install torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113 + python -m pip install . 
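+ # The CUDA 11.3 torch wheel installed above is assumed to match the NVIDIA driver available on the self-hosted runner.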
+ python -m pip install ".[test,k8s]" + ./ding/scripts/install-k8s-tools.sh + make mqbenchmark \ No newline at end of file diff --git a/Makefile b/Makefile index 39810b7871..c6ead4d1ab 100644 --- a/Makefile +++ b/Makefile @@ -57,11 +57,25 @@ benchmark: --durations=0 \ -sv -m benchmark +multiprocesstest: + pytest ${TEST_DIR} \ + --cov-report=xml \ + --cov-report term-missing \ + --cov=${COV_DIR} \ + ${DURATIONS_COMMAND} \ + ${WORKERS_COMMAND} \ + -sv -m multiprocesstest + +mqbenchmark: + pytest ${TEST_DIR} \ + --durations=0 \ + -sv -m mqbenchmark + test: unittest # just for compatibility, can be changed later cpu_test: unittest algotest benchmark -all_test: unittest algotest cudatest benchmark +all_test: unittest algotest cudatest benchmark multiprocesstest format: yapf --in-place --recursive -p --verbose --style .style.yapf ${FORMAT_DIR} diff --git a/codecov.yml b/codecov.yml index 0779ada773..af3e5c97dd 100644 --- a/codecov.yml +++ b/codecov.yml @@ -6,3 +6,10 @@ coverage: target: auto threshold: 0.5% if_ci_failed: success #success, failure, error, ignore + +# fix me +# The unittests of the torchrpc module are tested by different runners and cannot be included +# in the test_unittest's coverage report. To keep CI happy, we don't count torchrpc related coverage. +ignore: + - /mnt/cache/wangguoteng/DI-engine/ding/framework/message_queue/torch_rpc.py + - /mnt/cache/wangguoteng/DI-engine/ding/framework/message_queue/perfs/* diff --git a/ding/compatibility.py b/ding/compatibility.py index dd6b1fd0da..94d37991e0 100644 --- a/ding/compatibility.py +++ b/ding/compatibility.py @@ -7,3 +7,7 @@ def torch_ge_131(): def torch_ge_180(): return int("".join(list(filter(str.isdigit, torch.__version__)))) >= 180 + + +def torch_ge_1121(): + return int("".join(list(filter(str.isdigit, torch.__version__)))) >= 1121 diff --git a/ding/data/shm_buffer.py b/ding/data/shm_buffer.py index b76f5d56e9..875a7210c7 100644 --- a/ding/data/shm_buffer.py +++ b/ding/data/shm_buffer.py @@ -3,6 +3,10 @@ import ctypes import numpy as np import torch +import torch.multiprocessing as mp +from functools import reduce +from ditk import logging +from abc import abstractmethod _NTYPE_TO_CTYPE = { np.bool_: ctypes.c_bool, @@ -18,8 +22,37 @@ np.float64: ctypes.c_double, } +# uint16, uint32, uint32 +_NTYPE_TO_TTYPE = { + np.bool_: torch.bool, + np.uint8: torch.uint8, + # np.uint16: torch.int16, + # np.uint32: torch.int32, + # np.uint64: torch.int64, + np.int8: torch.uint8, + np.int16: torch.int16, + np.int32: torch.int32, + np.int64: torch.int64, + np.float32: torch.float32, + np.float64: torch.float64, +} + +_NOT_SUPPORT_NTYPE = {np.uint16: torch.int16, np.uint32: torch.int32, np.uint64: torch.int64} +_CONVERSION_TYPE = {np.uint16: np.int16, np.uint32: np.int32, np.uint64: np.int64} + + +class ShmBufferBase: + + @abstractmethod + def fill(self, src_arr: Union[np.ndarray, torch.Tensor]) -> None: + raise NotImplementedError -class ShmBuffer(): + @abstractmethod + def get(self) -> Union[np.ndarray, torch.Tensor]: + raise NotImplementedError + + +class ShmBuffer(ShmBufferBase): """ Overview: Shared memory buffer to store numpy array. @@ -78,6 +111,94 @@ def get(self) -> np.ndarray: return data +class ShmBufferCuda(ShmBufferBase): + + def __init__( + self, + dtype: Union[torch.dtype, np.dtype], + shape: Tuple[int], + ctype: Optional[type] = None, + copy_on_get: bool = True, + device: Optional[torch.device] = torch.device('cuda:0') + ) -> None: + """ + Overview: + Use torch.multiprocessing for shared tensor or ndaray between processes. 
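+ The buffer is a single flat tensor pre-allocated on ``device``; ``fill()`` copies data into it and ``get()`` copies (or views) it back out, so dtype and shape are fixed at construction time.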
+ Arguments: + - dtype (Union[torch.dtype, np.dtype]): dtype of torch.tensor or numpy.ndarray. + - shape (Tuple[int]): Shape of torch.tensor or numpy.ndarray. + - ctype (type): Origin class type, e.g. np.ndarray, torch.Tensor. + - copy_on_get (bool, optional): Can be set to False only if the shared object + is a tenor, otherwise True. + - device (Optional[torch.device], optional): The GPU device where cuda-shared-tensor + is located, the default is cuda:0. + + Raises: + RuntimeError: Unsupported share type by ShmBufferCuda. + """ + if isinstance(dtype, np.dtype): # it is type of gym.spaces.dtype + self.ctype = np.ndarray + dtype = dtype.type + if dtype in _NOT_SUPPORT_NTYPE.keys(): + logging.warning( + "Torch tensor unsupport numpy type {}, attempt to do a type conversion, which may lose precision.". + format(dtype) + ) + ttype = _NOT_SUPPORT_NTYPE[dtype] + self.dtype = _CONVERSION_TYPE[dtype] + else: + ttype = _NTYPE_TO_TTYPE[dtype] + self.dtype = dtype + elif isinstance(dtype, torch.dtype): + self.ctype = torch.Tensor + ttype = dtype + else: + raise RuntimeError("The dtype parameter only supports torch.dtype and np.dtype") + + self.copy_on_get = copy_on_get + self.shape = shape + self.device = device + # We don't want the buffer to be involved in the computational graph + with torch.no_grad(): + self.buffer = torch.zeros(reduce(lambda x, y: x * y, shape), dtype=ttype, device=self.device) + + def fill(self, src_arr: Union[np.ndarray, torch.Tensor]) -> None: + if self.ctype is np.ndarray: + if src_arr.dtype.type != self.dtype: + logging.warning( + "Torch tensor unsupport numpy type {}, attempt to do a type conversion, which may lose precision.". + format(self.dtype) + ) + src_arr = src_arr.astype(self.dtype) + tensor = torch.from_numpy(src_arr) + elif self.ctype is torch.Tensor: + tensor = src_arr + else: + raise RuntimeError("Unsopport CUDA-shared-tensor input type:\"{}\"".format(type(src_arr))) + + # If the GPU-a and GPU-b are connected using nvlink, the copy is very fast. + with torch.no_grad(): + self.buffer.copy_(tensor.view(tensor.numel())) + + def get(self) -> Union[np.ndarray, torch.Tensor]: + with torch.no_grad(): + if self.ctype is np.ndarray: + # Because ShmBufferCuda use CUDA memory exchanging data between processes. + # So copy_on_get is necessary for numpy arrays. + re = self.buffer.cpu() + re = re.detach().view(self.shape).numpy() + else: + if self.copy_on_get: + re = self.buffer.clone().detach().view(self.shape) + else: + re = self.buffer.view(self.shape) + + return re + + def __del__(self): + del self.buffer + + class ShmBufferContainer(object): """ Overview: @@ -88,7 +209,8 @@ def __init__( self, dtype: Union[Dict[Any, type], type, np.dtype], shape: Union[Dict[Any, tuple], tuple], - copy_on_get: bool = True + copy_on_get: bool = True, + is_cuda_buffer: bool = False ) -> None: """ Overview: @@ -98,11 +220,15 @@ def __init__( - shape (:obj:`Union[Dict[Any, tuple], tuple]`): If `Dict[Any, tuple]`, use a dict to manage \ multiple buffers; If `tuple`, use single buffer. - copy_on_get (:obj:`bool`): Whether to copy data when calling get method. + - is_cuda_buffer (:obj:`bool`): Whether to use pytorch CUDA shared tensor as the implementation of shm. 
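+ Note: when ``is_cuda_buffer`` is True the underlying buffer lives in CUDA memory (``ShmBufferCuda``), so consumer processes should be started with the ``spawn`` context of ``torch.multiprocessing``.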
""" if isinstance(shape, dict): - self._data = {k: ShmBufferContainer(dtype[k], v, copy_on_get) for k, v in shape.items()} + self._data = {k: ShmBufferContainer(dtype[k], v, copy_on_get, is_cuda_buffer) for k, v in shape.items()} elif isinstance(shape, (tuple, list)): - self._data = ShmBuffer(dtype, shape, copy_on_get) + if not is_cuda_buffer: + self._data = ShmBuffer(dtype, shape, copy_on_get) + else: + self._data = ShmBufferCuda(dtype, shape, copy_on_get) else: raise RuntimeError("not support shape: {}".format(shape)) self._shape = shape diff --git a/ding/data/tests/test_shm_buffer.py b/ding/data/tests/test_shm_buffer.py index 04334b4799..6316e40b66 100644 --- a/ding/data/tests/test_shm_buffer.py +++ b/ding/data/tests/test_shm_buffer.py @@ -1,20 +1,90 @@ +from ding.data.shm_buffer import ShmBuffer, ShmBufferCuda +from ding.compatibility import torch_ge_1121 + import pytest import numpy as np import timeit -from ding.data.shm_buffer import ShmBuffer -import multiprocessing as mp +import torch +import time -def subprocess(shm_buf): +def subprocess_np_shm(shm_buf): data = np.random.rand(1024, 1024).astype(np.float32) res = timeit.repeat(lambda: shm_buf.fill(data), repeat=5, number=1000) print("Mean: {:.4f}s, STD: {:.4f}s, Mean each call: {:.4f}ms".format(np.mean(res), np.std(res), np.mean(res))) +def subprocess_cuda_shared_tensor(shm_buf_np, shm_buf_torch, event_run): + event_run.wait() + rtensor = shm_buf_torch.get() + assert isinstance(rtensor, torch.Tensor) + assert rtensor.device == torch.device('cuda:0') + assert rtensor.dtype == torch.float32 + assert rtensor.sum().item() == 1024 * 1024 + + rarray = shm_buf_np.get() + assert isinstance(rarray, np.ndarray) + assert rarray.dtype == np.dtype(np.float32) + assert rarray.dtype == np.dtype(np.float32) + + res = timeit.repeat(lambda shm_buf_torch=shm_buf_torch: shm_buf_torch.get(), repeat=5, number=1000) + print("CUDA-shared-tensor (torch) Get: mean: {:.4f}s, STD: {:.4f}s".format(np.mean(res), np.std(res))) + res = timeit.repeat(lambda shm_buf_np=shm_buf_np: shm_buf_np.get(), repeat=5, number=1000) + print("CUDA-shared-tensor (numpy) Get: mean: {:.4f}s, STD: {:.4f}s".format(np.mean(res), np.std(res))) + + del shm_buf_np + del shm_buf_torch + + @pytest.mark.benchmark def test_shm_buffer(): + import multiprocessing as mp data = np.random.rand(1024, 1024).astype(np.float32) shm_buf = ShmBuffer(data.dtype, data.shape, copy_on_get=False) - proc = mp.Process(target=subprocess, args=[shm_buf]) + proc = mp.Process(target=subprocess_np_shm, args=[shm_buf]) proc.start() proc.join() + + +@pytest.mark.benchmark +@pytest.mark.cudatest +@pytest.mark.multiprocesstest +def test_cuda_shm(): + if torch.cuda.is_available() and torch.cuda.device_count() >= 2: + import torch.multiprocessing as mp + ctx = mp.get_context('spawn') + + event_run = ctx.Event() + shm_buf_np = ShmBufferCuda(np.dtype(np.float32), shape=(1024, 1024), copy_on_get=True) + shm_buf_torch = ShmBufferCuda(torch.float32, shape=(1024, 1024), copy_on_get=True) + proc = ctx.Process(target=subprocess_cuda_shared_tensor, args=[shm_buf_np, shm_buf_torch, event_run]) + proc.start() + + ltensor = torch.ones((1024, 1024), dtype=torch.float32).cuda(0 if torch.cuda.device_count() == 1 else 1) + larray = np.random.rand(1024, 1024).astype(np.float32) + shm_buf_torch.fill(ltensor) + shm_buf_np.fill(larray) + + res = timeit.repeat(lambda shm_buf_torch=shm_buf_torch: shm_buf_torch.fill(ltensor), repeat=5, number=1000) + print("CUDA-shared-tensor (torch) Fill: mean: {:.4f}s, STD: {:.4f}s".format(np.mean(res), 
np.std(res))) + res = timeit.repeat(lambda shm_buf_np=shm_buf_np: shm_buf_np.fill(larray), repeat=5, number=1000) + print("CUDA-shared-tensor (numpy) Fill: mean: {:.4f}s, STD: {:.4f}s".format(np.mean(res), np.std(res))) + + rtensor = shm_buf_torch.get() + assert isinstance(rtensor, torch.Tensor) + assert rtensor.device == torch.device('cuda:0') + assert rtensor.shape == ltensor.shape + assert rtensor.dtype == ltensor.dtype + + rarray = shm_buf_np.get() + assert isinstance(rarray, np.ndarray) + assert larray.shape == rarray.shape + assert larray.dtype == rarray.dtype + + event_run.set() + + # Keep producer process running until all consumers exits. + proc.join() + + del shm_buf_np + del shm_buf_torch diff --git a/ding/entry/cli_ditask.py b/ding/entry/cli_ditask.py index 443fe1a6b6..29af0af2ad 100644 --- a/ding/entry/cli_ditask.py +++ b/ding/entry/cli_ditask.py @@ -57,12 +57,36 @@ def print_version(ctx: Context, param: Option, value: bool) -> None: ) @click.option("--platform-spec", type=str, help="Platform specific configure.") @click.option("--platform", type=str, help="Platform type: slurm, k8s.") -@click.option("--mq-type", type=str, default="nng", help="Class type of message queue, i.e. nng, redis.") +@click.option( + "--mq-type", + type=str, + default="nng", + help="Class type of message queue, i.e. nng, redis, torchrpc:cuda, torchrpc:cpu." +) @click.option("--redis-host", type=str, help="Redis host.") @click.option("--redis-port", type=int, help="Redis port.") @click.option("-m", "--main", type=str, help="Main function of entry module.") @click.option("--startup-interval", type=int, default=1, help="Start up interval between each task.") @click.option("--local_rank", type=int, default=0, help="Compatibility with PyTorch DDP") +@click.option( + "--init-method", + type=str, + help="[Torchrpc]: Init method both for init_rpc and init_process_group, please refer to pytorch init_method" +) +@click.option( + "--local-cuda-devices", + type=str, + help='''[Torchrpc]: [Optional] Specifies the device ranks of the GPUs used by the local process, a comma-separated + list of integers.''' +) +@click.option( + "--cuda-device-map", + type=str, + help='''[Torchrpc]: [Optional] Specify device mapping. + Ref: + Format: --cuda-device-map=__,[...] 
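+ Example (illustrative values): --cuda-device-map=0_0_0,0_1_2 maps local GPU 0 to GPU 0 of Node_0 and local GPU 1 to GPU 2 of Node_0.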
+ ''' +) def cli_ditask(*args, **kwargs): return _cli_ditask(*args, **kwargs) @@ -107,9 +131,12 @@ def _cli_ditask( redis_host: str, redis_port: int, startup_interval: int, + init_method: str = None, local_rank: int = 0, platform: str = None, platform_spec: str = None, + local_cuda_devices: str = None, + cuda_device_map: str = None ): # Parse entry point all_args = locals() @@ -145,6 +172,18 @@ def _cli_ditask( if node_ids and not isinstance(node_ids, int): node_ids = node_ids.split(",") node_ids = list(map(lambda i: int(i), node_ids)) + use_cuda = False + if mq_type == "torchrpc:cuda" or mq_type == "torchrpc:cpu": + mq_type, use_cuda = mq_type.split(":") + if use_cuda == "cuda": + use_cuda = True + if local_cuda_devices: + local_cuda_devices = local_cuda_devices.split(",") + local_cuda_devices = list(map(lambda s: s.strip(), local_cuda_devices)) + if cuda_device_map: + cuda_device_map = cuda_device_map.split(",") + cuda_device_map = list(map(lambda s: s.strip(), cuda_device_map)) + Parallel.runner( n_parallel_workers=parallel_workers, ports=ports, @@ -157,5 +196,9 @@ def _cli_ditask( mq_type=mq_type, redis_host=redis_host, redis_port=redis_port, - startup_interval=startup_interval + init_method=init_method, + startup_interval=startup_interval, + use_cuda=use_cuda, + local_cuda_devices=local_cuda_devices, + cuda_device_map=cuda_device_map )(main_func) diff --git a/ding/envs/env_manager/subprocess_env_manager.py b/ding/envs/env_manager/subprocess_env_manager.py index 1648981f03..fdcc61de17 100644 --- a/ding/envs/env_manager/subprocess_env_manager.py +++ b/ding/envs/env_manager/subprocess_env_manager.py @@ -1,5 +1,6 @@ from typing import Any, Union, List, Tuple, Dict, Callable, Optional from multiprocessing import connection, get_context +# from torch.multiprocessing import connection, get_context from collections import namedtuple from ditk import logging import platform @@ -12,6 +13,7 @@ import cloudpickle import numpy as np import treetensor.numpy as tnp +import treetensor.torch as ttorch from easydict import EasyDict from types import MethodType from ding.data import ShmBufferContainer, ShmBuffer @@ -70,6 +72,7 @@ class AsyncSubprocessEnvManager(BaseEnvManager): retry_waiting_time=0.1, # subprocess specified args shared_memory=True, + cuda_shared_memory=False, copy_on_get=True, context='spawn' if platform.system().lower() == 'windows' else 'fork', wait_num=2, @@ -97,6 +100,7 @@ def __init__( """ super().__init__(env_fn, cfg) self._shared_memory = self._cfg.shared_memory + self._cuda_shared_memory = self._cfg.cuda_shared_memory if self._shared_memory else False self._copy_on_get = self._cfg.copy_on_get self._context = self._cfg.context self._wait_num = self._cfg.wait_num @@ -134,7 +138,9 @@ def _create_state(self) -> None: shape = obs_space.shape dtype = obs_space.dtype self._obs_buffers = { - env_id: ShmBufferContainer(dtype, shape, copy_on_get=self._copy_on_get) + env_id: ShmBufferContainer( + dtype, shape, copy_on_get=self._copy_on_get, is_cuda_buffer=self._cuda_shared_memory + ) for env_id in range(self.env_num) } else: @@ -148,7 +154,11 @@ def _create_state(self) -> None: def _create_env_subprocess(self, env_id): # start a new one - ctx = get_context(self._context) + if self._cuda_shared_memory: + import torch.multiprocessing as mp + ctx = mp.get_context('spawn') + else: + ctx = get_context(self._context) self._pipe_parents[env_id], self._pipe_children[env_id] = ctx.Pipe() self._subprocesses[env_id] = ctx.Process( # target=self.worker_fn, @@ -705,6 +715,7 @@ class 
SyncSubprocessEnvManager(AsyncSubprocessEnvManager): retry_waiting_time=0.1, # subprocess specified args shared_memory=True, + cuda_shared_memory=False, copy_on_get=True, context='spawn' if platform.system().lower() == 'windows' else 'fork', wait_num=float("inf"), # inf mean all the environments @@ -802,7 +813,7 @@ class SubprocessEnvManagerV2(SyncSubprocessEnvManager): """ @property - def ready_obs(self) -> tnp.array: + def ready_obs(self) -> Union[tnp.array, torch.Tensor]: """ Overview: Get the ready (next) observation in ``tnp.array`` type, which is uniform for both async/sync scenarios. @@ -822,7 +833,10 @@ def ready_obs(self) -> tnp.array: ) time.sleep(0.001) sleep_count += 1 - return tnp.stack([tnp.array(self._ready_obs[i]) for i in self.ready_env]) + if not self._cuda_shared_memory: + return tnp.stack([tnp.array(self._ready_obs[i]) for i in self.ready_env]) + else: + return ttorch.stack([ttorch.tensor(self._ready_obs[i]) for i in self.ready_env]) def step(self, actions: List[tnp.ndarray]) -> List[tnp.ndarray]: """ @@ -846,5 +860,16 @@ def step(self, actions: List[tnp.ndarray]) -> List[tnp.ndarray]: # in order to call them as attribute (e.g. timestep.xxx), such as ``TimeLimit.truncated`` in cartpole info info = make_key_as_identifier(info) info = remove_illegal_item(info) - new_data.append(tnp.array({'obs': obs, 'reward': reward, 'done': done, 'info': info, 'env_id': env_id})) + if not self._cuda_shared_memory: + new_data.append(tnp.array({'obs': obs, 'reward': reward, 'done': done, 'info': info, 'env_id': env_id})) + else: + new_data.append( + ttorch.tensor({ + 'obs': obs, + 'reward': reward, + 'done': done, + 'info': info, + 'env_id': env_id + }) + ) return new_data diff --git a/ding/framework/__init__.py b/ding/framework/__init__.py index 72c23d0475..fd489588e7 100644 --- a/ding/framework/__init__.py +++ b/ding/framework/__init__.py @@ -1,6 +1,6 @@ from .context import Context, OnlineRLContext, OfflineRLContext -from .task import Task, task, VoidMiddleware -from .parallel import Parallel +from .task import Task, task, VoidMiddleware, enable_async +from .parallel import Parallel, MQType from .event_loop import EventLoop from .supervisor import Supervisor from easydict import EasyDict diff --git a/ding/framework/message_queue/README.md b/ding/framework/message_queue/README.md new file mode 100644 index 0000000000..3267dbecfd --- /dev/null +++ b/ding/framework/message_queue/README.md @@ -0,0 +1,13 @@ +# Notes on using torchrpc + +## Problems you may encounter + +Message queue of Torchrpc uses [tensorpipe](https://github.com/pytorch/tensorpipe) as a communication backend, a high-performance modular tensor-p2p communication library. However, several tensorpipe defects have been found in the test, which may make it difficult for you to use it. + +### 1. container environment + +Tensorpipe is not container aware. Processes can find themselves on the same physical machine through `/proc/sys/kernel/random/boot_id` ,but because in separated pod/container, they cannot use means of communication such as CUDA ipc. When tensorpipe finds that these communication methods cannot be used, it will report an error and exit. + +### 2. RDMA and fork subprocess + +Tensorpipe does not consider the case of calling [fork(2)](https://man7.org/linux/man-pages/man2/fork.2.html) when using RDMA. If the corresponding initialization measures are not performed when using RDMA, using fork will cause serious problems, refer to [here](https://www.rdmamojo.com/2012/05/24/ibv_fork_init/). 
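The underlying issue is that memory pages registered for RDMA before a fork become copy-on-write, which can silently corrupt in-flight transfers in the parent or the child.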
Therefore, if you start ditask in the IB/RoCE network environment, please specify the environment variables `IBV_FORK_SAFE=1` and `RDMAV_FORK_SAFE=1` , so that ibverbs will automatically initialize fork support. \ No newline at end of file diff --git a/ding/framework/message_queue/__init__.py b/ding/framework/message_queue/__init__.py index 7cbbbcd93c..3cedbe11d7 100644 --- a/ding/framework/message_queue/__init__.py +++ b/ding/framework/message_queue/__init__.py @@ -1,3 +1,4 @@ from .mq import MQ from .redis import RedisMQ from .nng import NNGMQ +from .torch_rpc import TORCHRPCMQ, DeviceMap diff --git a/ding/framework/message_queue/mq.py b/ding/framework/message_queue/mq.py index 4386882020..37a6b61676 100644 --- a/ding/framework/message_queue/mq.py +++ b/ding/framework/message_queue/mq.py @@ -1,4 +1,4 @@ -from typing import Tuple +from typing import Tuple, Optional class MQ: @@ -31,12 +31,15 @@ def publish(self, topic: str, data: bytes) -> None: """ raise NotImplementedError - def subscribe(self, topic: str) -> None: + def subscribe(self, topic: str, fn: Optional[callable] = None, is_once: Optional[bool] = False) -> None: """ Overview: Subscribe to the topic. Arguments: - topic (:obj:`str`): Topic + - fn (:obj:`Optional[callable]`): The message handler, if the communication library + implements event_loop, it can bypass Parallel() and calling this function by itself. + - is_once (:obj:`bool`): Whether Topic will only be called once. """ raise NotImplementedError diff --git a/ding/framework/message_queue/nng.py b/ding/framework/message_queue/nng.py index 379601b0ed..5298fc0a55 100644 --- a/ding/framework/message_queue/nng.py +++ b/ding/framework/message_queue/nng.py @@ -39,7 +39,7 @@ def publish(self, topic: str, data: bytes) -> None: data = topic.encode() + data self._sock.send(data) - def subscribe(self, topic: str) -> None: + def subscribe(self, topic: str, fn: Optional[callable] = None, is_once: Optional[bool] = False) -> None: return def unsubscribe(self, topic: str) -> None: diff --git a/ding/framework/message_queue/perfs/perf_nng.py b/ding/framework/message_queue/perfs/perf_nng.py new file mode 100644 index 0000000000..d597518b54 --- /dev/null +++ b/ding/framework/message_queue/perfs/perf_nng.py @@ -0,0 +1,274 @@ +import pickle +import multiprocessing as mp +import argparse +import os +import time +import torch +import numpy as np +import click +import struct + +from time import sleep +from threading import Thread +from ding.framework.message_queue.nng import NNGMQ +from ditk import logging +from ding.framework.parallel import Parallel +from ding.utils.comm_perf_helper import byte_beauty_print, time_perf_avg, print_timer_result_csv +from ding.utils import EasyTimer, WatchDog + +logging.getLogger().setLevel(logging.INFO) +REPEAT = 10 +LENGTH = 5 +EXP_NUMS = 2 +UNIT_SIZE_LIST = [64, 1024, 64 * 1024, 512 * 1024, 2 * 1024 * 1024] + + +@click.command(context_settings=dict(help_option_names=['-h', '--help'])) +@click.option("--ports", type=str, default="50515") +@click.option("--attach-to", type=str, help="The addresses to connect to.") +@click.option("--address", type=str, help="The address to listen to (without port).") +@click.option("--labels", type=str, help="Labels.") +@click.option("--node-ids", type=str, help="Candidate node ids.") +def handle_args(*args, **kwargs): + return nng_perf_main(*args, **kwargs) + + +def pack_time(data, value): + if value: + return struct.pack('d', value) + "::".encode() + data + else: + return struct.pack('d', value) + + +def unpack_time(value): + 
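    # Decode an 8-byte double written by pack_time() back into a float (a timestamp or a timer value).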
return struct.unpack('=d', value)[0] + + +def nng_dist_main(labels, node_id, listen_to, attach_to, *arg, **kwargs) -> None: + """ + Overview: + Since nng message reception may be out of order, and nng + does not have a handshake, the sender may start + sending messages and timing before the receiver is ready. + So this function does the corresponding work. + """ + mq = NNGMQ(listen_to=listen_to, attach_to=attach_to) + mq.listen() + label = labels.pop() + rank = 0 + future_dict = dict() + start_tag = [] + finish_tag = [] + + def send_t(topic, data=None): + try: + if not data: + data = [0, 0] + data = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL) + mq.publish(topic, data) + logging.debug("send topic {}".format(topic)) + except Exception as e: + logging.error("send error at rank:{} label:\"{}\", topic:\"{}\", error: {}".format(rank, label, topic, e)) + + def recv_loop(): + while True: + topic, data = mq.recv() + if topic == "z": + # perf_nng_detail recv callback. + timestamps, data = data.split(b"::", maxsplit=1) + h2d_timer = EasyTimer(cuda=True) + pickle_timer = EasyTimer(cuda=False) + + with pickle_timer: + data = pickle.loads(data) + data, idx = data[0], data[1] + + with h2d_timer: + data = data.cuda(0) + + data = pickle.dumps([timestamps, idx], protocol=pickle.HIGHEST_PROTOCOL) + time_res = pack_time(data, pickle_timer.value) + time_res = pack_time(time_res, h2d_timer.value) + + mq.publish("k", time_res) + continue + elif topic == "k": + # perf_nng_detail send callback. + h2d_time, pickle_time, data = data.split(b"::", maxsplit=2) + data = pickle.loads(data) + timestamps, idx = data[0], data[1] + future_dict['perf_finsh'] = (unpack_time(h2d_time), unpack_time(pickle_time), unpack_time(timestamps)) + future_dict[idx] = 1 + continue + else: + # Callback functions for other tests. 
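+ # Topic "t": single CUDA tensor round-trip, "d": dict of tensors, "a": ack from the peer, "s": start handshake, "f": finish signal.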
+ data = pickle.loads(data) + data, idx = data[0], data[1] + if topic == "t": + assert isinstance(data, torch.Tensor) + data = data.cuda(0) + torch.cuda.synchronize(0) + pass + elif topic == "d": + assert isinstance(data, dict) + for k, v in data.items(): + data[k] = v.cuda(0) + torch.cuda.synchronize(0) + elif topic == "a": + if idx not in future_dict.keys(): + raise RuntimeError("Unkown idx") + future_dict[idx] = 1 + continue + elif topic == "s": + if label == 'collector': + send_t("s") + elif label == 'learner': + start_tag.append(1) + continue + elif topic == "f": + finish_tag.append(1) + return + else: + raise RuntimeError("Unkown topic") + + send_t("a", ["", idx]) + + def irendezvous(): + timeout_killer = WatchDog(3) + timeout_killer.start() + send_t("s") + while len(start_tag) == 0: + time.sleep(0.05) + timeout_killer.stop() + + listen_thread = Thread(target=recv_loop, name="recv_loop", daemon=True) + listen_thread.start() + + if label == 'learner': + while True: + try: + irendezvous() + except Exception as e: + logging.warning("timeout for irendezvous") + else: + break + + if label == 'learner': + + for size in UNIT_SIZE_LIST: + unit_size = size * LENGTH + gpu_data = torch.ones(unit_size).cuda(rank) + time_list = [list() for i in range(EXP_NUMS)] + size_lists = [[size] for i in range(LENGTH)] + send_func_list = [] + logging.info("Data size: {:.2f} {}".format(*byte_beauty_print(unit_size * 4))) + tensor_dict = dict() + for j, size_list in enumerate(size_lists): + tensor_dict[str(j)] = torch.ones(size_list).cuda(rank) + + @time_perf_avg(1, REPEAT, cuda=True) + def nng_tensor_sender_1(idx): + future_dict[idx] = 0 + send_t("t", [gpu_data.cpu(), idx]) + while future_dict[idx] == 0: + time.sleep(0.03) + + @time_perf_avg(1, REPEAT, cuda=True) + def nng_tensor_sender_2(idx): + tmp_dict = dict() + future_dict[idx] = 0 + for key, value in tensor_dict.items(): + tmp_dict[key] = value.cpu() + send_t("d", [tmp_dict, idx]) + while future_dict[idx] == 0: + time.sleep(0.03) + + def perf_nng_detail(idx): + future_dict[idx] = 0 + h2d_timer = EasyTimer(cuda=True) + pickle_timer = EasyTimer(cuda=False) + + with h2d_timer: + data = gpu_data.cpu() + + with pickle_timer: + data = pickle.dumps([data, idx], protocol=pickle.HIGHEST_PROTOCOL) + + data = pack_time(data, time.time()) + mq.publish("z", data) + + while future_dict[idx] == 0: + time.sleep(0.03) + + peer_h2d_time, peer_pickle_time, timestamps = future_dict['perf_finsh'] + total_time = time.time() - timestamps + # Serialization time + pickle_time = peer_pickle_time + pickle_timer.value + # H2D/D2H time + pcie_time = peer_h2d_time + h2d_timer.value + # TCP I/O time + IO_time = total_time - pickle_time - pcie_time + logging.info( + "Detailed: total:[{:.4f}]ms, pickle:[{:.4f}]ms, H2D/D2H:[{:.4f}]ms, I/O:[{:.4f}]ms".format( + total_time, pickle_time, pcie_time, IO_time + ) + ) + # print("{:.4f}, {:.4f}, {:.4f}, {:.4f}".format(total_time, pickle_time, pcie_time, IO_time)) + + send_func_list.append(nng_tensor_sender_1) + send_func_list.append(nng_tensor_sender_2) + + for i in range(len(send_func_list)): + for j in range(REPEAT): + send_func_list[i](j, i + j) + + # Determine the time-consuming of each stage of nng. 
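+ # The breakdown splits the measured round-trip into pickle (de)serialization, D2H/H2D copies and the remaining TCP I/O time.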
+ perf_nng_detail(0) + + # Do some proper cleanup to prevent cuda memory overflow + torch.cuda.empty_cache() + + if label == 'learner': + send_t("f") + finish_tag.append(1) + + while len(finish_tag) == 0: + time.sleep(0.1) + + print_timer_result_csv() + + +def nng_perf_main(ports: str, attach_to: str, address: str, labels: str, node_ids: str): + if not isinstance(ports, int): + ports = ports.split(",") + ports = list(map(lambda i: int(i), ports)) + ports = ports[0] if len(ports) == 1 else ports + if attach_to: + attach_to = attach_to.split(",") + attach_to = list(map(lambda s: s.strip(), attach_to)) + if labels: + labels = labels.split(",") + labels = set(map(lambda s: s.strip(), labels)) + if node_ids and not isinstance(node_ids, int): + node_ids = node_ids.split(",") + node_ids = list(map(lambda i: int(i), node_ids)) + + runner_params = Parallel._nng_args_parser( + n_parallel_workers=1, + ports=ports, + protocol="tcp", + attach_to=attach_to, + address=address, + labels=labels, + node_ids=node_ids, + ) + logging.debug(runner_params) + nng_dist_main(**runner_params[0]) + + +# Usages: +# CUDA_VISIBLE_DEVICES=0 python perf_nng.py --node-ids 0 --labels learner --ports 12345 --address 0.0.0.0 +# CUDA_VISIBLE_DEVICES=1 python perf_nng.py --node-ids 1 --labels collector --address 127.0.0.1 \ +# --ports 12355 --attach-to tcp://0.0.0.0:12345 +if __name__ == "__main__": + handle_args() diff --git a/ding/framework/message_queue/perfs/perf_shm.py b/ding/framework/message_queue/perfs/perf_shm.py new file mode 100644 index 0000000000..234f49213b --- /dev/null +++ b/ding/framework/message_queue/perfs/perf_shm.py @@ -0,0 +1,141 @@ +from typing import TYPE_CHECKING, Any, List, Union, Dict, Optional, Callable + +from ditk import logging +from ding.framework.supervisor import RecvPayload, SendPayload, Supervisor, ChildType +from ding.envs.env_manager.subprocess_env_manager import ShmBufferContainer, ShmBuffer +from ding.utils.comm_perf_helper import tensor_size_beauty_print, byte_beauty_print, \ + dtype_2_byte, TENSOR_SIZE_LIST, print_timer_result_csv + +import torch +import numpy as np +import time +import argparse + +LENGTH = 5 +REPEAT = 10 +UNIT_SIZE_LIST = [64, 1024, 64 * 1024, 512 * 1024, 2 * 1024 * 1024] +logging.getLogger().setLevel(logging.INFO) + + +def shm_callback(payload: RecvPayload, buffers: Any): + # Step4: shared memory -> np.array + np_tensor = buffers[payload.data["idx"]].get() + # Step5: np.array -> cpu tensor + tensor = torch.from_numpy(np_tensor) + # Step6: cpu tensor -> gpu tensor + tensor = tensor.cuda(0) + torch.cuda.synchronize(0) + + +def cuda_shm_callback(payload: RecvPayload, buffers: Any): + # Step2: gpu shared tensor -> gpu tensor + tensor = buffers[payload.data["idx"]].get() + assert tensor.device == torch.device('cuda:0') + # Step3: gpu tensor(cuda:0) -> gpu tensor(cuda:1) + tensor = tensor.to(1) + torch.cuda.synchronize(1) + assert tensor.device == torch.device('cuda:1') + + +class Recvier: + + def step(self, idx: int, __start_time): + return {"idx": idx, "start_time": __start_time} + + +class ShmSupervisor(Supervisor): + + def __init__(self, gpu_tensors, buffers, ctx, is_cuda_buffer): + super().__init__(type_=ChildType.PROCESS, mp_ctx=ctx) + self.gpu_tensors = gpu_tensors + self.buffers = buffers + self.time_list = [] + self._time_list = [] + self._is_cuda_buffer = is_cuda_buffer + if not is_cuda_buffer: + _shm_callback = shm_callback + else: + _shm_callback = cuda_shm_callback + self.register(Recvier, shm_buffer=self.buffers, shm_callback=_shm_callback) + 
super().start_link() + + def _send_recv_callback(self, payload: RecvPayload, remain_payloads: Optional[Dict[str, SendPayload]] = None): + idx = payload.data["idx"] + __start_time = payload.data["start_time"] + __end_time = time.time() + self.time_list.append(float(__end_time - __start_time) * 1000.0) + + def step(self): + # Do not use Queue to send large data, use shm. + for i, size in enumerate(UNIT_SIZE_LIST): + for j in range(REPEAT): + __start_time = time.time() + + if not self._is_cuda_buffer: + # Numpy shm buffer: + # Step1: gpu tensor -> cpu tensor + tensor = self.gpu_tensors[i].cpu() + # Step2: cpu tensor-> np.array + np_tensor = tensor.numpy() + # Step3: np.array -> shared memory + self.buffers[i].fill(np_tensor) + else: + # Cuda shared tensor + # Step1: gpu tensor -> gpu shared tensor + self.buffers[i].fill(self.gpu_tensors[i]) + + payload = SendPayload(proc_id=0, method="step", args=[i, __start_time]) + send_payloads = [payload] + + self.send(payload) + self.recv_all(send_payloads, ignore_err=True, callback=self._send_recv_callback) + + _avg_time = sum(self.time_list) / len(self.time_list) + self._time_list.append(_avg_time) + self.time_list.clear() + logging.info( + "Data size {:.2f} {} , repeat {}, avg RTT {:.4f} ms".format( + *byte_beauty_print(UNIT_SIZE_LIST[i] * 4 * LENGTH), REPEAT, _avg_time + ) + ) + + for t in self._time_list: + print("{:.4f},".format(t), end="") + print("") + + +def shm_perf_main(test_type: str): + gpu_tensors = list() + buffers = dict() + + if test_type == "shm": + import multiprocessing as mp + use_cuda_buffer = False + elif test_type == "cuda_ipc": + use_cuda_buffer = True + import torch.multiprocessing as mp + + ctx = mp.get_context('spawn') + + for i, size in enumerate(UNIT_SIZE_LIST): + unit_size = size * LENGTH + gpu_tensors.append(torch.ones(unit_size).cuda(0)) + if not use_cuda_buffer: + buffers[i] = ShmBufferContainer(np.float32, (unit_size, ), copy_on_get=True, is_cuda_buffer=False) + else: + buffers[i] = ShmBufferContainer(torch.float32, (unit_size, ), copy_on_get=True, is_cuda_buffer=True) + + sv = ShmSupervisor( + gpu_tensors=gpu_tensors, buffers=buffers, ctx=mp.get_context('spawn'), is_cuda_buffer=use_cuda_buffer + ) + sv.step() + del sv + + +# Usages: +# python perf_shm.py --test_type ["shm"|"cuda_ipc"] +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Test torch rpc') + parser.add_argument('--test_type', type=str) + args, _ = parser.parse_known_args() + shm_perf_main(args.test_type) diff --git a/ding/framework/message_queue/perfs/perf_torchrpc_nccl.py b/ding/framework/message_queue/perfs/perf_torchrpc_nccl.py new file mode 100644 index 0000000000..67fbb73e46 --- /dev/null +++ b/ding/framework/message_queue/perfs/perf_torchrpc_nccl.py @@ -0,0 +1,278 @@ +import time +import torch +import os +import argparse +import torch.distributed as dist +import treetensor.torch as ttorch + +from dataclasses import dataclass +from queue import Empty +from typing import TYPE_CHECKING, List, Dict, Union +from ditk import logging + +from ding.utils.data.structure.lifo_deque import LifoDeque +from ding.framework.message_queue.torch_rpc import DeviceMap, TORCHRPCMQ, RPCEvent +from ding.utils.comm_perf_helper import tensor_size_beauty_print, byte_beauty_print, \ + dtype_2_byte, DO_PERF, time_perf_avg, time_perf_once, print_timer_result_csv + +LENGTH = 5 +REPEAT = 2 +MAX_EXP_NUMS = 10 +UNIT_SIZE_LIST = [64, 1024, 64 * 1024, 512 * 1024, 2 * 1024 * 1024] + + +@dataclass +class SendInfo: + send_map_dict: Dict = None + sending_flag: int 
= 0 + + +# Global vars definition is here: +global mq +global global_send_info_dict +mq = None + +global_send_info_dict = dict() + + +def remote_mq_entrance(topic, *args, **kwargs): + global mq + mq.rpc_event_router(topic, *args, **kwargs) + + +def dict_tensor_send_by_key( + key: str, tensor_id: int, tensor: torch.Tensor, nums: int, send_id: int, use_cuda: bool +) -> None: + """ + Overview: + For data structures that use dict to store tensor, such as dict[key:tensor] or treetensor, + this function can be used. Each key is transmitted using one rpc, and the rpc transmission + of each key is asynchronous. + Arguments: + - key (str): Key in dict. + - tensor_id (int): The sending tensor ID during one dict/treetensor rpc transmission. + - tensor (torch.tensor): The tensor to be sent. + - nums (int): The total number of sent tensors. + - send_id (int): The ID of this dict/treetensor rpc transmission. + """ + global global_send_info_dict + send_info_dict = global_send_info_dict + send_info = None + + assert isinstance(key, str) + assert isinstance(tensor_id, int) + assert isinstance(tensor, torch.Tensor) + assert isinstance(nums, int) + assert isinstance(send_id, int) + assert isinstance(use_cuda, bool) + + if tensor_id == 0: + send_info = SendInfo() + send_info.send_map_dict = dict() + send_info_dict[send_id] = send_info + else: + while True: + if send_id in send_info_dict.keys(): + send_info = send_info_dict[send_id] + if send_info is not None: + break + + assert isinstance(send_info, SendInfo) + + if key in send_info.send_map_dict.keys(): + raise RuntimeError("Multiple state_dict's key \"{}\" received!".format(key)) + + send_info.send_map_dict[key] = tensor + + if tensor_id == nums - 1: + while len(send_info.send_map_dict) != nums: + time.sleep(0.01) + + send_info_dict.clear() + if use_cuda: + torch.cuda.synchronize(0) + return + + +def send_dummy(playload: Union[torch.Tensor, Dict], use_cuda: bool, *args) -> None: + assert isinstance(use_cuda, bool) + if use_cuda: + torch.cuda.synchronize(0) + return + + +def dict_tensor_send(mq: TORCHRPCMQ, state_dict: Dict, send_id: int, use_cuda: bool) -> None: + future_list = [] + for tensor_id, (key, value) in enumerate(state_dict.items()): + future_list.append(mq.publish("DICT_TENSOR_SEND", key, tensor_id, value, len(state_dict), send_id, use_cuda)) + + for future in future_list: + future.wait() + + +def perf_torch_rpc(use_cuda=True): + global LENGTH + global UNIT_SIZE_LIST + if use_cuda: + device = "cuda:0" + else: + device = "cpu" + + for i, unit_size in enumerate(UNIT_SIZE_LIST): + unit_tensor = torch.ones([unit_size * LENGTH]).to(device) + tensor_dict = {} + for j in range(LENGTH): + tensor_dict[str(j)] = torch.ones(unit_size).to(device) + + if use_cuda: + torch.cuda.synchronize(0) + + @time_perf_avg(1, REPEAT, cuda=use_cuda) + def one_shot_rpc(): + dict_tensor_send(mq, {'test': unit_tensor}, i, use_cuda) + + @time_perf_avg(1, REPEAT, cuda=use_cuda) + def one_shot_rpc_with_dict(): + dict_tensor_send(mq, tensor_dict, i, use_cuda) + + @time_perf_avg(1, REPEAT, cuda=use_cuda) + def split_chunk_rpc(): + re = mq.publish(RPCEvent.CUSTOM_FUNCRION_RPC, {'test': unit_tensor}, use_cuda, custom_method=send_dummy) + re.wait() + + @time_perf_avg(1, REPEAT, cuda=use_cuda) + def split_chunk_rpc_with_dict(): + re = mq.publish(RPCEvent.CUSTOM_FUNCRION_RPC, tensor_dict, use_cuda, custom_method=send_dummy) + re.wait() + + logging.debug("Size {:.2f} {}".format(*byte_beauty_print(unit_size * LENGTH * 4))) + + for idx in range(REPEAT): + one_shot_rpc(idx) + 
one_shot_rpc_with_dict(idx) + split_chunk_rpc(idx) + split_chunk_rpc_with_dict(idx) + + if use_cuda: + torch.cuda.empty_cache() + + +def perf_nccl(global_rank: int, use_cuda=True): + if use_cuda: + device = "cuda:0" + else: + device = "cpu" + ack_tensor = torch.ones(10).to(device) + + if global_rank == 0: + # Warm up recving + dist.recv(tensor=ack_tensor, src=1) + if use_cuda: + torch.cuda.synchronize(0) + + for i, unit_size in enumerate(UNIT_SIZE_LIST): + payload = torch.ones([unit_size * LENGTH]).to(device) + + @time_perf_avg(1, REPEAT, cuda=True) + def test_case_nccl(payload): + dist.send(tensor=payload, dst=1, tag=i) + + logging.debug("Size {:.2f} {}".format(*byte_beauty_print(unit_size * LENGTH * 4))) + + for idx in range(REPEAT): + test_case_nccl(idx, payload) + else: + # Warm up sending + dist.send(tensor=ack_tensor, dst=0) + if use_cuda: + torch.cuda.synchronize(0) + + for i, unit_size in enumerate(UNIT_SIZE_LIST): + recvbuffer = torch.ones([unit_size * LENGTH]).to(device) + for j in range(REPEAT): + dist.recv(tensor=recvbuffer, src=0, tag=i) + if use_cuda: + torch.cuda.synchronize(0) + + +def rpc_model_exchanger(rank: int, init_method: str, test_nccl: bool = False, use_cuda: bool = True): + global mq + global dict_tensor_send_by_key + global remote_mq_entrance + from ding.framework.parallel import Parallel + + logging.getLogger().setLevel(logging.DEBUG) + if test_nccl: + dist.init_process_group("nccl", rank=rank, world_size=2, init_method=init_method) + params = Parallel._torchrpc_args_parser( + n_parallel_workers=1, + attach_to=[1] if rank == 0 else [], + node_ids=[rank], + init_method=init_method, + use_cuda=use_cuda, + async_rpc=True, + async_backend_polling=False, + remote_parallel_entrance=remote_mq_entrance + )[0] + logging.debug(params) + mq = TORCHRPCMQ(**params) + mq.show_device_maps() + + # Because the dict_tensor_send_by_key() relies on global variables, we have to register it. + mq.subscribe("DICT_TENSOR_SEND", dict_tensor_send_by_key) + mq.listen() + + # In order to prevent deadlock caused by mixed use of "torch.cuda.synchronize" between + # nccl and torchrpc, we test the two backend separately. + if rank == 1: + # Receiver ready for testing nccl + if test_nccl: + perf_nccl(rank) + # Receiver join to wait sender to send shutdown signal. + mq.wait_for_shutdown() + elif rank == 0: + # Sender test torch rpc. + perf_torch_rpc(use_cuda=use_cuda) + # Sender test nccl. + if test_nccl: + perf_nccl(rank) + # Print test results. + print_timer_result_csv() + # Sender send finish signal. + mq.require_to_shutdown("Node_1") + # Sender clean resources. + mq.stop() + + +# Usage: +# CUDA_VISIBLE_DEVICES=0 python perf_torchrpc_nccl.py --rank=0 +# CUDA_VISIBLE_DEVICES=1 python perf_torchrpc_nccl.py --rank=1 +# +# Note: +# If you are in a container, please ensure that your /dev/shm is large enough. +# If there is a strange core or bug, please check if /dev/shm is full. 
+# If so, please try to clear it manually: +# /dev/shm/nccl* +# /dev/shm/cuda.shm.* +# /dev/shm/torch_* +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Test torch rpc') + parser.add_argument('--rank', type=int) + parser.add_argument('--init-method', type=str, default="tcp://127.0.0.1:12347") + parser.add_argument('--test_nccl', type=bool, default=False) + parser.add_argument('--use_cuda', type=bool, default=False) + args, _ = parser.parse_known_args() + + if args.use_cuda: + if "CUDA_VISIBLE_DEVICES" in os.environ: + logging.info("CUDA_VISIBLE_DEVICES: {}".format(os.environ['CUDA_VISIBLE_DEVICES'])) + else: + logging.info("Not set CUDA_VISIBLE_DEVICES!") + + logging.info( + "CUDA is enable:{}, nums of GPU: {}, current device: {}".format( + torch.cuda.is_available(), torch.cuda.device_count(), torch.cuda.current_device() + ) + ) + + rpc_model_exchanger(args.rank, args.init_method, args.test_nccl, args.use_cuda) diff --git a/ding/framework/message_queue/perfs/tests/test_perf_nng.py b/ding/framework/message_queue/perfs/tests/test_perf_nng.py new file mode 100644 index 0000000000..343d5d080d --- /dev/null +++ b/ding/framework/message_queue/perfs/tests/test_perf_nng.py @@ -0,0 +1,14 @@ +from ding.framework.message_queue.perfs.perf_nng import nng_perf_main +import multiprocessing as mp +import pytest + + +@pytest.mark.mqbenchmark +@pytest.mark.multiprocesstest +def test_nng(): + params = [ + ("12376", None, "127.0.0.1", "learner", "0"), ("12378", "tcp://127.0.0.1:12376", "127.0.0.1", "collector", "1") + ] + ctx = mp.get_context("spawn") + with ctx.Pool(processes=2) as pool: + pool.starmap(nng_perf_main, params) diff --git a/ding/framework/message_queue/perfs/tests/test_perf_shm.py b/ding/framework/message_queue/perfs/tests/test_perf_shm.py new file mode 100644 index 0000000000..2be1a2a047 --- /dev/null +++ b/ding/framework/message_queue/perfs/tests/test_perf_shm.py @@ -0,0 +1,20 @@ +from ding.framework.message_queue.perfs.perf_shm import shm_perf_main +import multiprocessing as mp +import pytest +import torch + + +@pytest.mark.mqbenchmark +@pytest.mark.cudatest +@pytest.mark.multiprocesstest +def test_shm_numpy_shm(): + if torch.cuda.is_available(): + shm_perf_main("shm") + + +@pytest.mark.mqbenchmark +@pytest.mark.cudatest +@pytest.mark.multiprocesstest +def test_shm_cuda_shared_tensor(): + if torch.cuda.is_available() and torch.cuda.device_count() >= 2: + shm_perf_main("cuda_ipc") diff --git a/ding/framework/message_queue/perfs/tests/test_perf_torchrpc_nccl.py b/ding/framework/message_queue/perfs/tests/test_perf_torchrpc_nccl.py new file mode 100644 index 0000000000..8af00e8a2a --- /dev/null +++ b/ding/framework/message_queue/perfs/tests/test_perf_torchrpc_nccl.py @@ -0,0 +1,18 @@ +from ding.framework.message_queue.perfs.perf_torchrpc_nccl import rpc_model_exchanger +from ding.compatibility import torch_ge_1121 +import multiprocessing as mp +import pytest +import torch +import platform + + +@pytest.mark.mqbenchmark +@pytest.mark.cudatest +@pytest.mark.multiprocesstest +def test_perf_torchrpc_nccl(): + if platform.system().lower() != 'windows' and torch.cuda.is_available(): + if torch_ge_1121() and torch.cuda.device_count() >= 2: + params = [(0, "tcp://127.0.0.1:12387", False, True), (1, "tcp://127.0.0.1:12387", False, True)] + ctx = mp.get_context("spawn") + with ctx.Pool(processes=2) as pool: + pool.starmap(rpc_model_exchanger, params) diff --git a/ding/framework/message_queue/redis.py b/ding/framework/message_queue/redis.py index 9cbf10e8a6..69860e3242 100644 --- 
a/ding/framework/message_queue/redis.py +++ b/ding/framework/message_queue/redis.py @@ -1,7 +1,7 @@ import uuid from ditk import logging from time import sleep -from typing import Tuple +from typing import Tuple, Optional import redis from ding.framework.message_queue.mq import MQ @@ -34,7 +34,7 @@ def publish(self, topic: str, data: bytes) -> None: data = self._id + b"::" + data self._client.publish(topic, data) - def subscribe(self, topic: str) -> None: + def subscribe(self, topic: str, fn: Optional[callable] = None, is_once: Optional[bool] = False) -> None: self._sub.subscribe(topic) def unsubscribe(self, topic: str) -> None: diff --git a/ding/framework/message_queue/tests/test_torch_rpc.py b/ding/framework/message_queue/tests/test_torch_rpc.py new file mode 100644 index 0000000000..1adf979021 --- /dev/null +++ b/ding/framework/message_queue/tests/test_torch_rpc.py @@ -0,0 +1,227 @@ +from ding.framework.message_queue.torch_rpc import DeviceMap, TORCHRPCMQ, DEFAULT_DEVICE_MAP_NUMS +from torch.distributed import rpc +from multiprocessing import Pool, get_context +from ding.compatibility import torch_ge_1121 +from ditk import logging + +import pytest +import torch +import platform +import time + +mq = None +recv_tensor_list = [None, None, None, None] + + +def remote_mq_entrance(topic, *args, **kwargs): + global mq + mq.rpc_event_router(topic, *args, **kwargs) + + +def torchrpc(rank): + global mq + global recv_tensor_list + mq = None + recv_tensor_list = [None, None, None, None] + logging.getLogger().setLevel(logging.DEBUG) + name_list = ["A", "B", "C", "D"] + + if rank == 0: + attach_to = name_list[1:] + else: + attach_to = None + + mq = TORCHRPCMQ( + rpc_name=name_list[rank], + global_rank=rank, + init_method="tcp://127.0.0.1:12398", + remote_parallel_entrance=remote_mq_entrance, + attach_to=attach_to, + async_rpc=False, + use_cuda=False + ) + + def fn1(tensor: torch.Tensor) -> None: + global recv_tensor_list + global mq + recv_tensor_list[0] = tensor + assert recv_tensor_list[0].sum().item() == 1000 + mq.publish("RANK_N_SEND", torch.ones(10), mq.global_rank) + + def fn2(tensor: torch.Tensor, rank) -> None: + global recv_tensor_list + recv_tensor_list[rank] = tensor + assert recv_tensor_list[rank].sum().item() == 10 + + mq.subscribe(topic="RANK_0_SEND", fn=fn1) + mq.subscribe(topic="RANK_N_SEND", fn=fn2) + mq.listen() + + if rank == 0: + mq.publish("RANK_0_SEND", torch.ones(1000)) + + mq._rendezvous_until_world_size(4) + all_worker_info = rpc._get_current_rpc_agent().get_worker_infos() + rpc.api._barrier([worker.name for worker in all_worker_info]) + + mq.unsubscribe("RANK_0_SEND") + assert "RANK_0_SEND" not in mq._rpc_events + + if rank == 0: + mq.publish("RANK_0_SEND", torch.ones(1000)) + + mq._rendezvous_until_world_size(4) + rpc.api._barrier(name_list) + mq.stop() + + +def torchrpc_cuda(rank): + global mq + global recv_tensor_list + mq = None + recv_tensor_list = [None, None, None, None] + name_list = ["A", "B"] + logging.getLogger().setLevel(logging.DEBUG) + + if rank == 0: + attach_to = name_list[1:] + else: + attach_to = None + + peer_rank = int(rank == 0) or 0 + peer_name = name_list[peer_rank] + device_map = DeviceMap(rank, [peer_name], [rank], [peer_rank]) + logging.debug(device_map) + + mq = TORCHRPCMQ( + rpc_name=name_list[rank], + global_rank=rank, + init_method="tcp://127.0.0.1:12390", + remote_parallel_entrance=remote_mq_entrance, + attach_to=attach_to, + device_maps=device_map, + async_rpc=False, + cuda_device=rank, + use_cuda=True + ) + + def fn1(tensor: torch.Tensor) -> 
None: + global recv_tensor_list + global mq + recv_tensor_list[0] = tensor + assert recv_tensor_list[0].sum().item() == 777 + assert recv_tensor_list[0].device == torch.device(1) + + mq.subscribe(topic="RANK_0_SEND", fn=fn1) + mq.listen() + + if rank == 0: + mq.publish("RANK_0_SEND", torch.ones(777).cuda(0)) + + mq._rendezvous_until_world_size(2) + all_worker_info = rpc._get_current_rpc_agent().get_worker_infos() + rpc.api._barrier([worker.name for worker in all_worker_info]) + mq.stop() + + +def torchrpc_args_parser(rank): + global mq + global recv_tensor_list + from ding.framework.parallel import Parallel + logging.getLogger().setLevel(logging.DEBUG) + + params = Parallel._torchrpc_args_parser( + n_parallel_workers=1, + attach_to=[], + node_ids=[0], + init_method="tcp://127.0.0.1:12399", + use_cuda=True, + local_cuda_devices=None, + cuda_device_map=None + )[0] + + logging.debug(params) + + # 1. If attach_to is empty, init_rpc will not block. + mq = TORCHRPCMQ(**params) + mq.listen() + assert mq._running + mq.stop() + assert not mq._running + logging.debug("[Pass] 1. If attach_to is empty, init_rpc will not block.") + + # 2. n_parallel_workers != len(node_ids) + try: + Parallel._torchrpc_args_parser(n_parallel_workers=999, attach_to=[], node_ids=[1, 2])[0] + except RuntimeError as e: + logging.debug("[Pass] 2. n_parallel_workers != len(node_ids).") + else: + assert False + + # 3. len(local_cuda_devices) != n_parallel_workers + try: + Parallel._torchrpc_args_parser(n_parallel_workers=8, node_ids=[1], local_cuda_devices=[1, 2, 3])[0] + except RuntimeError as e: + logging.debug("[Pass] 3. len(local_cuda_devices) != n_parallel_workers.") + else: + assert False + + # 4. n_parallel_workers > gpu_nums + # TODO(wgt): Support spwan mode to start torchrpc process using CPU/CUDA and CPU only. + try: + Parallel._torchrpc_args_parser(n_parallel_workers=999, node_ids=[1], use_cuda=True)[0] + except RuntimeError as e: + logging.debug("[Pass] 4. n_parallel_workers > gpu_nums.") + else: + assert False + + # 5. Set custom device map. + params = Parallel._torchrpc_args_parser( + n_parallel_workers=1, node_ids=[1], cuda_device_map=["0_0_0", "0_1_2", "1_1_4"] + )[0] + assert params['device_maps'].peer_name_list == ["Node_0", "Node_0", "Node_1"] + assert params['device_maps'].our_device_list == [0, 1, 1] + assert params['device_maps'].peer_device_list == [0, 2, 4] + # logging.debug(params['device_maps']) + logging.debug("[Pass] 5. Set custom device map.") + + # 6. Set n_parallel_workers > 1 + params = Parallel._torchrpc_args_parser(n_parallel_workers=8, node_ids=[1]) + assert len(params) == 8 + assert params[7]['node_id'] == 8 + assert params[0]['use_cuda'] is False + assert params[0]['device_maps'] is None + assert params[0]['cuda_device'] is None + + if torch.cuda.device_count() >= 2: + params = Parallel._torchrpc_args_parser(n_parallel_workers=2, node_ids=[1], use_cuda=True) + assert params[0]['use_cuda'] + assert len(params[0]['device_maps'].peer_name_list) == DEFAULT_DEVICE_MAP_NUMS - 1 + logging.debug("[Pass] 6. 
Set n_parallel_workers > 1.") + + +@pytest.mark.multiprocesstest +def test_torchrpc(): + ctx = get_context("spawn") + if platform.system().lower() != 'windows' and torch_ge_1121(): + with ctx.Pool(processes=4) as pool: + pool.map(torchrpc, range(4)) + + +@pytest.mark.cudatest +@pytest.mark.multiprocesstest +def test_torchrpc_cuda(): + if platform.system().lower() != 'windows': + if torch_ge_1121() and torch.cuda.is_available() and torch.cuda.device_count() >= 2: + ctx = get_context("spawn") + with ctx.Pool(processes=2) as pool: + pool.map(torchrpc_cuda, range(2)) + + +@pytest.mark.cudatest +@pytest.mark.multiprocesstest +def test_torchrpc_parser(): + if platform.system().lower() != 'windows' and torch_ge_1121() and torch.cuda.is_available(): + ctx = get_context("spawn") + with ctx.Pool(processes=1) as pool: + pool.map(torchrpc_args_parser, range(1)) diff --git a/ding/framework/message_queue/torch_rpc.py b/ding/framework/message_queue/torch_rpc.py new file mode 100644 index 0000000000..cee70c8bfb --- /dev/null +++ b/ding/framework/message_queue/torch_rpc.py @@ -0,0 +1,391 @@ +from ding.framework.message_queue.mq import MQ +from ding.utils import MQ_REGISTRY +from ditk import logging +from ding.utils import LockContext, LockContextType + +from typing import List, Optional, Tuple, Dict, Any, Union, Callable +from threading import Thread +from enum import Enum + +from torch.distributed import rpc + +import os +import time +import queue +import torch +import platform + +if platform.system().lower() != 'windows': + from torch.distributed.rpc import TensorPipeRpcBackendOptions + +DEFAULT_DEVICE_MAP_NUMS = 12 + + +# About RPCEvent: +# RPCEvent stores events that are not related to RL train logic. +# Private events use "int" to represent topic in order to reduce overhead, because the +# order and content of these events are hard-coded. The user-defined topic is uniquely +# identified by a string, because we cannot guarantee the order in which each process +# registers the same topic. +# +# There are four types of private events: +# 1. "CLINET_REGISTER_STUB": Responsible for the connect. +# 2. "CUSTOM_FUNCRION_RPC": Responsible for RPC which using provided RPC methods +# The remote function must be given with the positional parameter "custom_method". +# "custom_method" must be picklable, otherwise use subscribe() to register topic +# and corresponding method on the client side in advance. +# 3. "NOTIFY_SHUTDOWN": Responsible for the disconnect info from other process. +# 4. "REQUIRE_SHUTDOWN": Responsible for the disconnect request which was asked for. +class RPCEvent(int, Enum): + CLINET_REGISTER_STUB = 1 + CUSTOM_FUNCRION_RPC = 2 + NOTIFY_SHUTDOWN = 3 + REQUIRE_SHUTDOWN = 4 + + +class DeviceMap: + + def __init__( + self, + our_name: str, + peer_name_list: List[str] = None, + our_device_list: List[int] = None, + peer_device_list: List[int] = None + ) -> None: + """ + Overview: + Mapping management for gpu devices. + Arguments: + - peer_name_list (List[str], optional): remote processes unique rpc name. + - our_device_list (List[int], optional): local processes device rank. + - peer_device_list (List[int], optional): remote processes device rank. 
+ """ + + self.peer_name_list = peer_name_list or [] + self.our_device_list = our_device_list or [] + self.peer_device_list = peer_device_list or [] + + assert len(self.peer_name_list) == len(self.peer_name_list) + assert len(self.peer_device_list) == len(self.peer_device_list) + + self.our_name = str(our_name) + + def __str__(self): + info = "" + for i in range(len(self.peer_name_list)): + info += "{} : GPU-{} --> {} : GPU-{};{}".format( + self.our_name, str(self.our_device_list[i]), str(self.peer_name_list[i]), str(self.peer_device_list[i]), + "\n" if i != len(self.peer_name_list) - 1 else "" + ) + return info + + def set_device(self, option) -> None: + """ + Overview: + Initialize TensorPipeRpcBackendOptions according to the GPU mapping + set by the user. + Arguments: + - option (class TensorPipeRpcBackendOptions) + """ + for i in range(len(self.peer_name_list)): + option.set_device_map(self.peer_name_list[i], {self.our_device_list[i]: self.peer_device_list[i]}) + + +@MQ_REGISTRY.register("torchrpc") +class TORCHRPCMQ(MQ): + + def __init__( + self, + rpc_name: str, + init_method: str, + remote_parallel_entrance: Callable, + global_rank: int = 0, + attach_to: Optional[List[str]] = None, + device_maps: Optional[DeviceMap] = None, + async_rpc: Optional[bool] = True, + async_backend_polling: Optional[bool] = False, + use_cuda: Optional[bool] = False, + cuda_device: Optional[int] = None, + channels: Optional[List[str]] = None, + **kwargs + ) -> None: + """ + Overview: + Connect distributed processes with torch.distributed.rpc + Arguments: + - rpc_name (str): Globally unique name for rpc + - init_method (str): URL specifying how to initialize the process group. + - remote_parallel_entrance (Callable): Get the entry function of the remote Parallel() + struct. This function must ensure that the remote method call can find the corresponding + TORCHRPCMQ struct locally. + - attach_to (Optional[List[str]], optional): The ranks want to connect to, comma-separated ranks. + - global_rank (int, optional): Globally unique id. + - device_maps (DeviceMap, optional): Used for torch rpc init device_maps. + - async_rpc (Optional[bool]): Whether to use asynchronous rpc, the default is false. + - async_backend_polling (Optional[bool]): Whether to enable background threads to poll future objects + generated by asynchronous RPCs. + - use_cuda (Optional[bool]): Whether there will be data on the GPU side involved in the communication, + if true, torchrpc will set the device map. + - cuda_device (Optional[int]): An optional list of local devices, the default is all visible devices. + - channels (Optional[List[str]]): Channels contain the communication methods used by tensorpipe when + transmitting tensor, including the following possible values: "basic", "cma", "mpt_uv", "cuda_ipc", + "cuda_gdr", "cuda_xth". 
+ """ + self.name = rpc_name + self.global_rank = global_rank + + self._running = False + self.remote_parallel_entrance = remote_parallel_entrance + + self._peer_set = set(attach_to if attach_to else []) + self._peer_set_lock = LockContext(type_=LockContextType.THREAD_LOCK) + + if platform.system().lower() != 'windows': + self.rpc_backend_options = TensorPipeRpcBackendOptions( + num_worker_threads=16, rpc_timeout=30, init_method=init_method, _channels=channels + ) + else: + raise WindowsError("TensorPipe does not support Windows yet!") + + if use_cuda: + assert torch.cuda.is_available() + assert device_maps + assert cuda_device is not None + + self._device_maps = device_maps + self._device_maps.set_device(self.rpc_backend_options) + self.rpc_backend_options.set_devices([cuda_device]) + else: + self._device_maps = None + + self._rpc_events = { + RPCEvent.CLINET_REGISTER_STUB: self.accept_rpc_connect, + RPCEvent.CUSTOM_FUNCRION_RPC: self.call_custom_rpc_method, + RPCEvent.NOTIFY_SHUTDOWN: self.notify_shutdown, + RPCEvent.REQUIRE_SHUTDOWN: self.stop + } + + self._async = async_rpc + self._async_backend_polling = async_rpc and async_backend_polling + if self._async_backend_polling: + self.async_future_queue = queue.Queue() + # Using threads to poll performance suffers due to the presence of Python GIL locks. + self.polling_thread = Thread(target=self._backend_polling, name="backend_polling", daemon=True) + + logging.debug( + "Torchrpc info: process name:\"{}\", node_id:[{}], attach_to[{}], init_method:{}.".format( + self.name, self.global_rank, self._peer_set, init_method + ) + ) + + def show_device_maps(self): + if self._device_maps: + logging.info("{}".format(self._device_maps)) + else: + logging.info("Not set device map!") + + def subscribe(self, topic: Union[int, str], fn: Optional[Callable] = None, is_once: Optional[bool] = False) -> None: + if fn is None: + raise RuntimeError("The Torchrpc subscription topic must be provided with a callback function.") + if topic not in self._rpc_events: + + def once_callback(*args, **kwargs): + fn(*args, **kwargs) + self.unsubscribe(topic) + + self._rpc_events[topic] = fn if not is_once else once_callback + + def unsubscribe(self, topic: Union[int, str]) -> None: + if topic in self._rpc_events: + self._rpc_events.pop(topic) + + def rpc_event_router(self, topic: Union[int, str], *args, **kwargs) -> Any: + """ + Overview: + Entry function called after all remote methods reach the target process. + Arguments: + - topic (Union[int, str]): Recevied topic. + """ + if topic not in self._rpc_events: + logging.warning("{} Torchrpc topic \"{}\" is not registered.".format(self.name, topic)) + return + + return (self._rpc_events[topic])(*args, **kwargs) + + def listen(self) -> None: + # If device_map is not specified, init_rpc will block until all processes + # smaller than the current rank call init_rpc. If device_map is specified, + # then init_rpc blocks until all processes present in device_map call init_rpc. 
+ rpc.init_rpc(name=self.name, rank=self.global_rank, rpc_backend_options=self.rpc_backend_options) + + # Wait for all processes rendezvous before starting subsequent steps + for i, peer in enumerate(self._peer_set): + while True: + try: + self._do_rpc(peer, RPCEvent.CLINET_REGISTER_STUB, self.name, self.global_rank) + except Exception as e: + logging.debug( + "\"{}\" try to rendezvous with \"{}\" error, because \"{}\"".format(self.name, peer, e) + ) + time.sleep(0.5) + continue + else: + logging.debug("\"{}\" irendezvous with \"{}\" success!".format(self.name, peer)) + break + + if self._async_backend_polling: + self.polling_thread.start() + self._running = True + + logging.debug("\"{}\" Torchrpc backend init success.".format(self.name)) + + def publish(self, topic: Union[int, str], *args, **kwargs) -> Any: + if self._running: + timeout_list = [] + + if len(self._peer_set) == 0: + logging.warning("No peer available to communicate with") + return + + with self._peer_set_lock: + for peer in self._peer_set: + if not self._running: + break + try: + re = self._do_rpc(peer, topic, *args, **kwargs) + except RuntimeError as e: + logging.error("Publish topic \"{}\" to peer \"{}\" has error: \"{}\"!".format(topic, peer, e)) + timeout_list.append(peer) + + for timeout_peer in timeout_list: + self._peer_set.remove(timeout_peer) + + return re + + def accept_rpc_connect(self, peer_name: str, peer_rank: int) -> None: + """ + Overview: + Receive the link signal sent by the peer. + Arguments: + - peer_name (str) + - peer_rank (int): + """ + with self._peer_set_lock: + if peer_name not in self._peer_set: + self._peer_set.add(peer_name) + + return + + def call_custom_rpc_method(self, *args, **kwargs) -> Any: + """ + Overview: + If the upper-level module wants to pass in a custom rpc method, + it will be called remotly by this function. + """ + fn = kwargs.pop('custom_method') + return fn(*args, **kwargs) + + def notify_shutdown(self, peer_name: str, *args, **kwargs) -> None: + """ + Overview: + Receive the exit signal sent by the peer. + Arguments: + - peer_name (str) + """ + with self._peer_set_lock: + if peer_name in self._peer_set: + logging.info("\"{}\" recv shutdown info from \"{}\".".format(self.name, peer_name)) + self._peer_set.remove(peer_name) + + def recv(self): + raise NotImplementedError + + def stop(self) -> None: + if self._running: + with self._peer_set_lock: + for peer in self._peer_set: + try: + self._do_rpc(peer, RPCEvent.NOTIFY_SHUTDOWN, self.name) + except RuntimeError as e: + continue + + if self._async_backend_polling: + while self.async_future_queue.qsize() > 0: + time.sleep(0.05) + continue + + self.polling_thread.join(timeout=1) + self.polling_thread = None + + self._running = False + + # Set graceful=False, we do not wait for other RPC processes to reach this method. + rpc.shutdown(graceful=False) + + logging.info("\"{}\" Torchrpc backend is stopped.".format(self.name)) + + def require_to_shutdown(self, peer_name: str): + """ + Overview: + Request the remote torch rpc message queue to be stopped. 
+ Arguments: + - peer_name (str): Remote torch rpc message's name + """ + try: + re = self._do_rpc(peer_name, RPCEvent.REQUIRE_SHUTDOWN) + if self._async and not self._async_backend_polling: + re.wait() + except RuntimeError as e: + logging.warning("Torchrpc polling_thread error: \"{}\".".format(e)) + + def _rendezvous_until_world_size(self, world_size) -> None: + while True: + all_worker_info = rpc._get_current_rpc_agent().get_worker_infos() + if len(all_worker_info) != world_size: + time.sleep(0.5) + else: + break + + def _do_rpc(self, peer: str, topic: Union[int, str] = Optional[None], *arg, **kwargs) -> Union[None, Any]: + """ + Overview: + Where the actual RPC communication takes place + Arguments: + - peer (str): [The rpc name of the peer] + - topic (int): [The topic passed by upstream] + """ + arg = [topic] + list(arg) + + if self._async: + future = rpc.rpc_async(peer, self.remote_parallel_entrance, args=arg, kwargs=kwargs) + if not self._async_backend_polling: + return future + else: + self.async_future_queue.put(future) + return None + else: + return rpc.rpc_sync(peer, self.remote_parallel_entrance, args=arg, kwargs=kwargs) + + def _backend_polling(self) -> None: + while True: + if not self._running: + break + + future = self.async_future_queue.get() + try: + if not future.done(): + time.sleep(0.05) + future.wait() + except RuntimeError as e: + logging.warning("Torchrpc polling thread catch RuntimeError: \"{}\".".format(e)) + + def wait_for_shutdown(self): + """ + Overview: + The thread calling this method will block until mq receives a request for shutdown. + """ + while True: + if not self._running: + break + else: + time.sleep(0.5) diff --git a/ding/framework/middleware/distributer.py b/ding/framework/middleware/distributer.py index c68a4b808f..8f53068138 100644 --- a/ding/framework/middleware/distributer.py +++ b/ding/framework/middleware/distributer.py @@ -2,8 +2,10 @@ from dataclasses import fields from typing import TYPE_CHECKING, List, Dict, Any, Optional, Union from ditk import logging -from ding.framework import task +from ding.framework import task, MQType from ding.data import StorageLoader, Storage, ModelLoader +from ding.utils import LockContext, LockContextType + if TYPE_CHECKING: from ding.framework.context import Context from torch.nn import Module @@ -11,7 +13,11 @@ class ContextExchanger: - def __init__(self, skip_n_iter: int = 1, storage_loader: Optional[StorageLoader] = None) -> None: + def __init__( + self, + skip_n_iter: int = 1, + storage_loader: Optional[StorageLoader] = None, + ) -> None: """ Overview: Exchange context between processes, @@ -33,9 +39,16 @@ def __init__(self, skip_n_iter: int = 1, storage_loader: Optional[StorageLoader] self._event_name = "context_exchanger_{role}" self._skip_n_iter = skip_n_iter self._storage_loader = storage_loader + + # Both nng and torchrpc use background threads to trigger the receiver's recv action, + # there is a race condition between sender and sender, and between senders and receiver. 
+ self._put_lock = LockContext(LockContextType.THREAD_LOCK) + self._recv_ready = False + self._bypass_eventloop = task.router.mq_type == MQType.RPC + for role in task.role: # Only subscribe to other roles if not task.has_role(role): - task.on(self._event_name.format(role=role), self.put) + task.on(self._event_name.format(role=role), self.put, bypass_eventloop=self._bypass_eventloop) if storage_loader: task.once("finish", lambda _: storage_loader.shutdown()) @@ -62,7 +75,12 @@ def __call__(self, ctx: "Context"): if self._storage_loader and task.has_role(task.role.COLLECTOR): payload = self._storage_loader.save(payload) for role in task.roles: - task.emit(self._event_name.format(role=role), payload, only_remote=True) + task.emit( + self._event_name.format(role=role), + payload, + only_remote=True, + bypass_eventloop=self._bypass_eventloop + ) def __del__(self): if self._storage_loader: @@ -76,12 +94,14 @@ def put(self, payload: Union[Dict, Storage]): """ def callback(payload: Dict): - for key, item in payload.items(): - fn_name = "_put_{}".format(key) - if hasattr(self, fn_name): - getattr(self, fn_name)(item) - else: - logging.warning("Receive unexpected key ({}) in context exchanger".format(key)) + with self._put_lock: + for key, item in payload.items(): + fn_name = "_put_{}".format(key) + if hasattr(self, fn_name): + getattr(self, fn_name)(item) + else: + logging.warning("Receive unexpected key ({}) in context exchanger".format(key)) + self._recv_ready = True if isinstance(payload, Storage): assert self._storage_loader is not None, "Storage loader is not defined when data is a storage object." @@ -106,26 +126,29 @@ def fetch(self, ctx: "Context") -> Dict[str, Any]: return payload def merge(self, ctx: "Context"): + if task.has_role(task.role.LEARNER): # Learner should always wait for trajs. # TODO: Automaticlly wait based on properties, not roles. - while len(self._state) == 0: + while self._recv_ready is False: sleep(0.01) elif ctx.total_step >= self._skip_n_iter: start = time() - while len(self._state) == 0: + while self._recv_ready is False: if time() - start > 60: logging.warning("Timeout when waiting for new context! 
Node id: {}".format(task.router.node_id)) break sleep(0.01) - for k, v in self._state.items(): - if not task.has_role(task.role.COLLECTOR) and k.startswith('increment_'): - pure_k = k.split('increment_')[-1] - setattr(ctx, pure_k, getattr(ctx, pure_k) + v) - else: - setattr(ctx, k, v) - self._state = {} + with self._put_lock: + for k, v in self._state.items(): + if not task.has_role(task.role.COLLECTOR) and k.startswith('increment_'): + pure_k = k.split('increment_')[-1] + setattr(ctx, pure_k, getattr(ctx, pure_k) + v) + else: + setattr(ctx, k, v) + self._state = {} + self._recv_ready = False # Handle each attibute of context def _put_trajectories(self, traj: List[Any]): @@ -150,14 +173,14 @@ def _fetch_episodes(self, episodes: List[Any]): if task.has_role(task.role.COLLECTOR): return episodes - def _put_trajectory_end_idx(self, trajectory_end_idx: List[str]): + def _put_trajectory_end_idx(self, trajectory_end_idx: List[int]): if not task.has_role(task.role.LEARNER): return if "trajectory_end_idx" not in self._state: self._state["trajectory_end_idx"] = [] self._state["trajectory_end_idx"].extend(trajectory_end_idx) - def _fetch_trajectory_end_idx(self, trajectory_end_idx: List[str]): + def _fetch_trajectory_end_idx(self, trajectory_end_idx: List[int]): if task.has_role(task.role.COLLECTOR): return trajectory_end_idx @@ -179,12 +202,6 @@ def _put_env_episode(self, increment_env_episode: int): self._state['increment_env_episode'] = 0 self._state["increment_env_episode"] += increment_env_episode - def _fetch_env_episode(self, env_episode: int): - if task.has_role(task.role.COLLECTOR): - increment_env_episode = env_episode - self._local_state['env_episode'] - self._local_state['env_episode'] = env_episode - return increment_env_episode - def _put_train_iter(self, train_iter: int): if not task.has_role(task.role.LEARNER): self._state["train_iter"] = train_iter @@ -211,8 +228,9 @@ def __init__(self, model: "Module", model_loader: Optional[ModelLoader] = None) self._event_name = "model_exchanger" self._state_dict_cache: Optional[Union[object, Storage]] = None self._is_learner = task.has_role(task.role.LEARNER) + self._bypass_eventloop = task.router.mq_type == MQType.RPC if not self._is_learner: - task.on(self._event_name, self._cache_state_dict) + task.on(self._event_name, self._cache_state_dict, bypass_eventloop=self._bypass_eventloop) if model_loader: task.once("finish", lambda _: model_loader.shutdown()) @@ -278,11 +296,13 @@ def _send_model(self): if self._model_loader: self._model_loader.save(self._send_callback) else: - task.emit(self._event_name, self._model.state_dict(), only_remote=True) + task.emit( + self._event_name, self._model.state_dict(), only_remote=True, bypass_eventloop=self._bypass_eventloop + ) def _send_callback(self, storage: Storage): if task.running: - task.emit(self._event_name, storage, only_remote=True) + task.emit(self._event_name, storage, only_remote=True, bypass_eventloop=self._bypass_eventloop) def __del__(self): if self._model_loader: diff --git a/ding/framework/middleware/functional/collector.py b/ding/framework/middleware/functional/collector.py index 20820d7d00..16930db826 100644 --- a/ding/framework/middleware/functional/collector.py +++ b/ding/framework/middleware/functional/collector.py @@ -5,6 +5,7 @@ from ding.envs import BaseEnvManager from ding.policy import Policy from ding.torch_utils import to_ndarray, get_shape0 +from ding.torch_utils import to_device if TYPE_CHECKING: from ding.framework import OnlineRLContext @@ -98,6 +99,10 @@ def 
rolloutor(policy: Policy, env: BaseEnvManager, transitions: TransitionList) env_episode_id = [_ for _ in range(env.env_num)] current_id = env.env_num + use_cuda_shared_memory = False + + if hasattr(cfg, "env") and hasattr(cfg.env, "manager"): + use_cuda_shared_memory = cfg.env.manager.cuda_shared_memory def _rollout(ctx: "OnlineRLContext"): """ @@ -113,16 +118,30 @@ def _rollout(ctx: "OnlineRLContext"): trajectory stops. """ - nonlocal current_id + nonlocal current_id, use_cuda_shared_memory timesteps = env.step(ctx.action) ctx.env_step += len(timesteps) - timesteps = [t.tensor() for t in timesteps] + + if not use_cuda_shared_memory: + timesteps = [t.tensor() for t in timesteps] + # TODO abnormal env step for i, timestep in enumerate(timesteps): transition = policy.process_transition(ctx.obs[i], ctx.inference_output[i], timestep) transition = ttorch.as_tensor(transition) # TBD transition.collect_train_iter = ttorch.as_tensor([ctx.train_iter]) transition.env_data_id = ttorch.as_tensor([env_episode_id[timestep.env_id]]) + + # torchrpc currently uses "cuda:0" as the transmission device by default, + # so all data on the cpu side is copied to "cuda:0" here. In fact this + # copy is unnecessary, because torchrpc can support both cpu side and gpu + # side data to communicate using RDMA, but mixing the two transfer types + # will cause a bug, see issue: + # Because we have copied the large payload "obs" and "next_obs" from the + # collector's subprocess to "cuda:0" in advance, the copy operation here + # will not have too much overhead. + if use_cuda_shared_memory: + transition = to_device(transition, "cuda:0") transitions.append(timestep.env_id, transition) if timestep.done: policy.reset([timestep.env_id]) diff --git a/ding/framework/parallel.py b/ding/framework/parallel.py index 38e343e495..e8369a4476 100644 --- a/ding/framework/parallel.py +++ b/ding/framework/parallel.py @@ -8,20 +8,31 @@ from ditk import logging import tempfile import socket +import enum from os import path -from typing import Callable, Dict, List, Optional, Tuple, Union, Set +from typing import Callable, Dict, List, Optional, Tuple, Union, Set, Any from threading import Thread from ding.framework.event_loop import EventLoop from ding.utils.design_helper import SingletonMetaclass from ding.framework.message_queue import * from ding.utils.registry_factory import MQ_REGISTRY +from easydict import EasyDict +from ding.framework.message_queue.torch_rpc import DeviceMap, DEFAULT_DEVICE_MAP_NUMS # Avoid ipc address conflict, random should always use random seed random = random.Random() +class MQType(int, enum.Enum): + NNG = 0 + REDIS = 1 + RPC = 2 + + class Parallel(metaclass=SingletonMetaclass): + _MQtype_dict = {"nng": MQType.NNG, "redis": MQType.REDIS, "torchrpc": MQType.RPC} + def __init__(self) -> None: # Init will only be called once in a process self._listener = None @@ -29,7 +40,6 @@ def __init__(self) -> None: self.node_id = None self.local_id = None self.labels = set() - self._event_loop = EventLoop("parallel_{}".format(id(self))) self._retries = 0 # Retries in auto recovery def _run( @@ -52,9 +62,18 @@ def _run( self.auto_recover = auto_recover self.max_retries = max_retries self._mq = MQ_REGISTRY.get(mq_type)(**kwargs) + self.mq_type = self._MQtype_dict[mq_type] + + if self.mq_type != MQType.RPC: + self._event_loop = EventLoop("parallel_{}".format(id(self))) + time.sleep(self.local_id * self.startup_interval) - self._listener = Thread(target=self.listen, name="mq_listener", daemon=True) - self._listener.start() + if 
self.mq_type == MQType.RPC: + self._mq.listen() + self.rpc_name = self._mq.name + else: + self._listener = Thread(target=self.listen, name="mq_listener", daemon=True) + self._listener.start() @classmethod def runner( @@ -72,7 +91,11 @@ def runner( max_retries: int = float("inf"), redis_host: Optional[str] = None, redis_port: Optional[int] = None, - startup_interval: int = 1 + init_method: Optional[str] = "env://", + startup_interval: int = 1, + use_cuda: Optional[bool] = False, + local_cuda_devices: Optional[List[str]] = None, + cuda_device_map: Optional[List[str]] = None ) -> Callable: """ Overview: @@ -100,7 +123,11 @@ def runner( """ all_args = locals() del all_args["cls"] - args_parsers = {"nng": cls._nng_args_parser, "redis": cls._redis_args_parser} + args_parsers = { + MQType.NNG: cls._nng_args_parser, + MQType.REDIS: cls._redis_args_parser, + MQType.RPC: cls._torchrpc_args_parser + } assert n_parallel_workers > 0, "Parallel worker number should bigger than 0" @@ -111,7 +138,7 @@ def _runner(main_process: Callable, *args, **kwargs) -> None: Arguments: - main_process (:obj:`Callable`): The main function, your program start from here. """ - runner_params = args_parsers[mq_type](**all_args) + runner_params = args_parsers[cls._MQtype_dict[mq_type]](**all_args) params_group = [] for i, runner_kwargs in enumerate(runner_params): runner_kwargs["local_id"] = i @@ -297,8 +324,13 @@ def on(self, event: str, fn: Callable) -> None: - fn (:obj:`Callable`): Function body. """ if self.is_active: - self._mq.subscribe(event) - self._event_loop.on(event, fn) + if self.mq_type == MQType.RPC: + self._mq.subscribe(event, fn) + else: + self._mq.subscribe(event) + + if hasattr(self, "_event_loop"): + self._event_loop.on(event, fn) def once(self, event: str, fn: Callable) -> None: """ @@ -310,8 +342,13 @@ def once(self, event: str, fn: Callable) -> None: - fn (:obj:`Callable`): Function body. """ if self.is_active: - self._mq.subscribe(event) - self._event_loop.once(event, fn) + if self.mq_type == MQType.RPC: + self._mq.subscribe(event, fn, True) + else: + self._mq.subscribe(event) + + if hasattr(self, "_event_loop"): + self._event_loop.once(event, fn) def off(self, event: str) -> None: """ @@ -322,7 +359,9 @@ def off(self, event: str) -> None: """ if self.is_active: self._mq.unsubscribe(event) - self._event_loop.off(event) + + if hasattr(self, "_event_loop"): + self._event_loop.off(event) def emit(self, event: str, *args, **kwargs) -> None: """ @@ -332,13 +371,16 @@ def emit(self, event: str, *args, **kwargs) -> None: - event (:obj:`str`): Event name. """ if self.is_active: - payload = {"a": args, "k": kwargs} - try: - data = pickle.dumps(payload, protocol=pickle.HIGHEST_PROTOCOL) - except AttributeError as e: - logging.error("Arguments are not pickable! Event: {}, Args: {}".format(event, args)) - raise e - self._mq.publish(event, data) + if self.mq_type == MQType.RPC: + self._mq.publish(event, *args, **kwargs) + else: + payload = {"a": args, "k": kwargs} + try: + data = pickle.dumps(payload, protocol=pickle.HIGHEST_PROTOCOL) + except AttributeError as e: + logging.error("Arguments are not pickable! Event: {}, Args: {}".format(event, args)) + raise e + self._mq.publish(event, data) def _handle_message(self, topic: str, msg: bytes) -> None: """ @@ -349,6 +391,8 @@ def _handle_message(self, topic: str, msg: bytes) -> None: - msg (:obj:`bytes`): Recevied message. 
""" event = topic + assert hasattr(self, "_event_loop") and self._event_loop + if not self._event_loop.listened(event): logging.debug("Event {} was not listened in parallel {}".format(event, self.node_id)) return @@ -382,10 +426,182 @@ def stop(self): logging.info("Stopping parallel worker on node: {}".format(self.node_id)) self.is_active = False time.sleep(0.03) - if self._mq: + if hasattr(self, "_mq") and self._mq: self._mq.stop() self._mq = None if self._listener: self._listener.join(timeout=1) self._listener = None - self._event_loop.stop() + if hasattr(self, "_event_loop"): + self._event_loop.stop() + + @classmethod + def make_device_maps(cls, + self_id: int, + local_cuda_device: int, + cuda_device_map: Optional[List[str]] = None) -> List[DeviceMap]: + dmap = DeviceMap(f"Node_{self_id}") + if cuda_device_map: + # If the user gave a custom device map, use it. + for item in cuda_device_map: + remote_node_id, local_device_rank, remote_device_rank = item.split("_") + dmap.peer_name_list.append(f"Node_{remote_node_id}") + dmap.our_device_list.append(int(local_device_rank)) + dmap.peer_device_list.append(int(remote_device_rank)) + else: + assert self_id < DEFAULT_DEVICE_MAP_NUMS + # If the user does not provide deivce_map and specifies the use of GPU, we default + # each process to use GPU:0 for communication. This is a convenient approach in a + # container environment. + for i in range(DEFAULT_DEVICE_MAP_NUMS): + if i == self_id: + continue + dmap.peer_name_list.append(f"Node_{i}") + dmap.our_device_list.append(local_cuda_device) + dmap.peer_device_list.append(0) + + return dmap + + @classmethod + def _torchrpc_args_parser( + cls, + n_parallel_workers: int, + attach_to: Optional[List[str]] = None, + node_ids: Optional[Union[List[int], int]] = None, + init_method: Optional[str] = "env://", + use_cuda: Optional[bool] = False, + local_cuda_devices: Optional[List[str]] = None, + cuda_device_map: Optional[List[str]] = None, + remote_parallel_entrance: Optional[Callable] = None, + async_rpc: Optional[bool] = True, + async_backend_polling: Optional[bool] = False, + channels: Optional[List[str]] = None, + **kwargs + ) -> List[Dict[str, dict]]: + import torch + assert init_method + + attach_to = attach_to or [] + node_divice_dict = dict() + + if local_cuda_devices or cuda_device_map: + use_cuda = True + if local_cuda_devices and not cuda_device_map: + logging.warning( + '''If you set local_cuda_devices but not cuda_device_map, torchrpc will use the default + device mapping to map all local GPU devices to the peer GPU-0.''' + ) + + # From the unique identification of each process when using torchrpc to communicate. + local_process_ids = cls.padding_param(node_ids, n_parallel_workers, 0) + attach_to = [f"Node_{id}" for id in attach_to] + nodes = ["Node_{}".format(id) for id in local_process_ids] + + try: + # torchrpc uses "node_id" as global rank, perform necessary checks here. + assert local_process_ids + assert len(local_process_ids) == n_parallel_workers + assert len(set(local_process_ids)) == n_parallel_workers + except AssertionError as e: + raise RuntimeError( + '''Arg "node_ids" must be specified. Please set the number of "node_ids" to be the same as + "n_parallel_workers" (Hint: "node_id" is the unique identifier between processes)''' + ) + + if use_cuda: + assert torch.cuda.is_available() + if local_cuda_devices: + if len(local_cuda_devices) != n_parallel_workers: + raise RuntimeError( + "The length of the \"local_cuda_devices\":[\"{}\"] is != \"n_parallel_workers\":[\"{}\"]". 
+ format(len(local_cuda_devices), n_parallel_workers) + ) + local_cuda_devices = [int(i) for i in local_cuda_devices] + else: + gpu_nums = torch.cuda.device_count() + if n_parallel_workers > gpu_nums: + raise RuntimeError( + "The number of available GPUS [\"{}\"] is less than n_parallel_workers[\"{}\"]".format( + gpu_nums, n_parallel_workers + ) + ) + local_cuda_devices = cls.padding_param(0, n_parallel_workers, 0) + + dmap_lists = [ + cls.make_device_maps(node_id, local_cuda_devices[i], cuda_device_map) + for i, node_id in enumerate(local_process_ids) + ] + else: + local_cuda_devices = [None for i in range(n_parallel_workers)] + dmap_lists = [None for i in range(n_parallel_workers)] + + if channels: + list_channels = [channels for i in range(n_parallel_workers)] + else: + list_channels = [None for i in range(n_parallel_workers)] + + global local_parallel_entrance + entrance_fn = remote_parallel_entrance if remote_parallel_entrance else local_parallel_entrance + runner_params = [] + for i in range(n_parallel_workers): + runner_kwargs = { + **kwargs, "node_id": local_process_ids[i], + "n_parallel_workers": n_parallel_workers, + "rpc_name": nodes[i], + "global_rank": local_process_ids[i], + "init_method": init_method, + "remote_parallel_entrance": entrance_fn, + "attach_to": attach_to, + "device_maps": dmap_lists[i], + "cuda_device": local_cuda_devices[i], + "use_cuda": use_cuda, + "async_rpc": async_rpc, + "async_backend_polling": async_backend_polling, + "channels": list_channels[i] + } + runner_params.append(runner_kwargs) + + return runner_params + + def get_mq(self): + return self._mq + + def judge_use_cuda_shm(self, cfg: EasyDict) -> None: + """ + Overview: + Only when torchrpc is used and env uses shared memory, cuda tensor + is used as the communication method between env subprocesses and + collector process. + Arguments: + - cfg (:obj:`EasyDict`): Input config dict which is to be used in the following pipeline. + """ + if not hasattr(cfg, "env") or not hasattr(cfg.env, "manager"): + return + + if cfg.env.manager.shared_memory: + if self.mq_type == MQType.RPC and "collector" in self.labels: + cfg.env.manager.cuda_shared_memory = True + return + cfg.env.manager.cuda_shared_memory = False + return + + +def local_parallel_entrance(topic: Union[int, str], *args, **kwargs) -> Any: + """ + Overview: + We must provide a method for all RPC methods to obtain the data structure + instantiated in the remote process. Because we don't want to and can't pickle + data structures such as Task() or Parallel(). + + Unlike nng, torchrpc needs to consider thread safety. Class 'Parallel' is a singleton + class. At this moment, Parallel() must have been instantiated, because + 'accept_rpc_connect'will only be executed after local-side init_rpc has completed, + + This function must be picklable, so should not be a local function. + + This function will be called concurrently by multiple threads, and the provider of + the RPC method needs to ensure that its own RPC method is thread-safe. + Arguments: + - topic (Union[int, str]): Recevied topic. 
+ """ + return Parallel().get_mq().rpc_event_router(topic, *args, **kwargs) diff --git a/ding/framework/task.py b/ding/framework/task.py index ae6e0e256d..131349bb40 100644 --- a/ding/framework/task.py +++ b/ding/framework/task.py @@ -13,7 +13,7 @@ import inspect from ding.framework.context import Context -from ding.framework.parallel import Parallel +from ding.framework.parallel import Parallel, MQType from ding.framework.event_loop import EventLoop from functools import wraps @@ -201,6 +201,7 @@ def run(self, max_step: int = int(1e12)) -> None: assert self._running, "Please make sure the task is running before calling the this method, see the task.start" if len(self._middleware) == 0: return + start_time = 0 for i in range(max_step): for fn in self._middleware: self.forward(fn) @@ -215,6 +216,11 @@ def run(self, max_step: int = int(1e12)) -> None: break self.renew() + if i == 0: + # Skip the first round of timing + start_time = time.time() + return start_time + def wrap(self, fn: Callable, lock: Union[bool, Lock] = False) -> Callable: """ Overview: @@ -424,7 +430,15 @@ def async_executor(self, fn: Callable, *args, **kwargs) -> None: t = self._async_loop.run_in_executor(self._thread_pool, fn, *args, **kwargs) self._async_stack.append(t) - def emit(self, event: str, *args, only_remote: bool = False, only_local: bool = False, **kwargs) -> None: + def emit( + self, + event: str, + *args, + only_remote: bool = False, + only_local: bool = False, + bypass_eventloop: bool = False, + **kwargs + ) -> None: """ Overview: Emit an event, call listeners. @@ -432,43 +446,66 @@ def emit(self, event: str, *args, only_remote: bool = False, only_local: bool = - event (:obj:`str`): Event name. - only_remote (:obj:`bool`): Only broadcast the event to the connected nodes, default is False. - only_local (:obj:`bool`): Only emit local event, default is False. + - bypass_eventloop (:obj:`bool`): Whether to select to bypass eventloop of Task() and Parallel(), + this parameter can only be True when torchrpc is used as the communication backend. If use torchrpc, + the invoked of the callback is triggered by the torchrpc's backend thread. - args (:obj:`any`): Rest arguments for listeners. """ # Check if need to broadcast event to connected nodes, default is True assert self._running, "Please make sure the task is running before calling the this method, see the task.start" - if only_local: - self._event_loop.emit(event, *args, **kwargs) - elif only_remote: + if bypass_eventloop: if self.router.is_active: - self.async_executor(self.router.emit, self._wrap_event_name(event), event, *args, **kwargs) + self.router.emit(self._wrap_event_name(event), *args, **kwargs) else: - if self.router.is_active: - self.async_executor(self.router.emit, self._wrap_event_name(event), event, *args, **kwargs) + if only_local: + self._event_loop.emit(event, *args, **kwargs) + elif only_remote: + if self.router.is_active: + self.async_executor(self.router.emit, self._wrap_event_name(event), event, *args, **kwargs) + else: + if self.router.is_active: + self.async_executor(self.router.emit, self._wrap_event_name(event), event, *args, **kwargs) self._event_loop.emit(event, *args, **kwargs) - def on(self, event: str, fn: Callable) -> None: + def on(self, event: str, fn: Callable, bypass_eventloop: Optional[bool] = False) -> None: """ Overview: Subscribe to an event, execute this function every time the event is emitted. Arguments: - event (:obj:`str`): Event name. - fn (:obj:`Callable`): The function. 
+ - bypass_eventloop (:obj:`bool`): Same as the bypass_eventloop arg in Task.emit. """ - self._event_loop.on(event, fn) - if self.router.is_active: - self.router.on(self._wrap_event_name(event), self._event_loop.emit) + if bypass_eventloop: + if self.router.mq_type == MQType.RPC: + if self.router.is_active: + self.router.on(self._wrap_event_name(event), fn) + else: + raise RuntimeError("Only message queue implemented by torchrpc allows bypass eventloop") + else: + self._event_loop.on(event, fn) + if self.router.is_active: + self.router.on(self._wrap_event_name(event), self._event_loop.emit) - def once(self, event: str, fn: Callable) -> None: + def once(self, event: str, fn: Callable, bypass_eventloop: Optional[bool] = False) -> None: """ Overview: Subscribe to an event, execute this function only once when the event is emitted. Arguments: - event (:obj:`str`): Event name. - fn (:obj:`Callable`): The function. + - bypass_eventloop (:obj:`bool`): Same as the bypass_eventloop arg in Task.emit. """ - self._event_loop.once(event, fn) - if self.router.is_active: - self.router.on(self._wrap_event_name(event), self._event_loop.emit) + if bypass_eventloop: + if self.router.mq_type == MQType.RPC: + if self.router.is_active: + self.router.once(self._wrap_event_name(event), fn) + else: + raise RuntimeError("Only message queue implemented by torchrpc allows bypass eventloop") + else: + self._event_loop.once(event, fn) + if self.router.is_active: + self.router.on(self._wrap_event_name(event), self._event_loop.emit) def off(self, event: str, fn: Optional[Callable] = None) -> None: """ diff --git a/ding/torch_utils/data_helper.py b/ding/torch_utils/data_helper.py index 8e6d026499..7b098698bd 100644 --- a/ding/torch_utils/data_helper.py +++ b/ding/torch_utils/data_helper.py @@ -62,6 +62,8 @@ def to_device(item: Any, device: str, ignore_keys: list = []) -> Any: return item elif isinstance(item, torch.distributions.Distribution): # for compatibility return item + elif isinstance(item, ttorch.Tensor): + return item.to(device) else: raise TypeError("not support item type: {}".format(type(item))) diff --git a/ding/utils/__init__.py b/ding/utils/__init__.py index 5e262b2c38..88f39e0d3d 100644 --- a/ding/utils/__init__.py +++ b/ding/utils/__init__.py @@ -11,7 +11,7 @@ from .k8s_helper import get_operator_server_kwargs, exist_operator_server, DEFAULT_K8S_COLLECTOR_PORT, \ DEFAULT_K8S_LEARNER_PORT, DEFAULT_K8S_AGGREGATOR_SLAVE_PORT, DEFAULT_K8S_COORDINATOR_PORT, pod_exec_command, \ K8sLauncher -from .lock_helper import LockContext, LockContextType, get_file_lock, get_rw_file_lock +from .lock_helper import LockContext, LockContextType, get_file_lock, get_rw_file_lock, synchronized from .log_helper import build_logger, pretty_print, LoggerFactory from .log_writer_helper import DistributedWriter from .orchestrator_launcher import OrchestratorLauncher @@ -36,3 +36,6 @@ else: from .pytorch_ddp_dist_helper import get_rank, get_world_size, dist_mode, dist_init, dist_finalize, \ allreduce, broadcast, DistContext, allreduce_async, synchronize + +from .comm_perf_helper import TENSOR_SIZE_LIST, DO_PERF, tensor_size_beauty_print, byte_beauty_print, \ + dtype_2_byte, time_perf_avg, time_perf_once, print_timer_result_csv diff --git a/ding/utils/comm_perf_helper.py b/ding/utils/comm_perf_helper.py new file mode 100644 index 0000000000..416b794c56 --- /dev/null +++ b/ding/utils/comm_perf_helper.py @@ -0,0 +1,145 @@ +import torch +import functools +import time +from concurrent import futures +from ditk import logging +from typing 
import List, Optional, Tuple, Dict, Any +from ding.utils import EasyTimer + +# Data size for some tests +UNIT_1_B = 1 +UNIT_1_KB = 1024 * UNIT_1_B +UNIT_1_MB = 1024 * UNIT_1_KB +UNIT_1_GB = 1024 * UNIT_1_MB +TENSOR_SIZE_LIST = [ + 8 * UNIT_1_B, 32 * UNIT_1_B, 64 * UNIT_1_B, UNIT_1_KB, 4 * UNIT_1_KB, 64 * UNIT_1_KB, 1 * UNIT_1_MB, 4 * UNIT_1_MB, + 64 * UNIT_1_MB, 512 * UNIT_1_MB, 1 * UNIT_1_GB, 2 * UNIT_1_GB, 4 * UNIT_1_GB +] + +# TODO: Add perf switch to avoid performance loss to critical paths during non-test time. +DO_PERF = False + +# Convert from torch.dtype to bytes +TYPE_MAP = {torch.float32: 4, torch.float64: 8, torch.int32: 4, torch.int64: 8, torch.uint8: 1} + +# A list of time units and names. +TIME_UNIT = [1, 1000, 1000] +TIME_NAME = ["s", "ms", "us"] + +# The global function timing result is stored in OUTPUT_DICT. +OUTPUT_DICT = dict() + + +def _store_timer_result(func_name: str, avg_tt: float): + if func_name not in OUTPUT_DICT.keys(): + OUTPUT_DICT[func_name] = str(round(avg_tt, 4)) + "," + else: + OUTPUT_DICT[func_name] = OUTPUT_DICT[func_name] + str(round(avg_tt, 4)) + "," + + +def print_timer_result_csv(): + """ + Overview: + Output the average execution time of all functions durning this + experiment in csv format. + """ + for key, value in OUTPUT_DICT.items(): + print("{},{}".format(key, value)) + + +def time_perf_once(unit: int, cuda: bool = False): + """ + Overview: + Decorator function to measure the time of a function execution. + Arguments: + - unit ([int]): 0 for s timer, 1 for ms timer, 2 for us timer. + - cuda (bool, optional): Whether CUDA operation occurred within the timing range. + """ + + def decorator(func): + + @functools.wraps(func) + def wrapper(*args, **kw): + timer = EasyTimer(cuda=cuda) + with timer: + func(*args, **kw) + tt = timer.value * TIME_UNIT[unit] + logging.info("func:\"{}\" use {:.4f} {},".format(func.__name__, tt, TIME_NAME[unit])) + + _store_timer_result(func.__name__, tt) + + return wrapper + + return decorator + + +def time_perf_avg(unit: int, count: int, skip_iter: int = 0, cuda: bool = False): + """ + Overview: + A decorator that averages the execution time of a function. + Arguments: + - unit (int): 0 for s timer, 1 for ms timer, 2 for us timer + - time_list (List): User-supplied list for staging execution times. + - count (int): Loop count. + - skip_iter (int, optional): Skip the first n iter times. + - cuda (bool, optional): Whether CUDA operation occurred within the timing range. + """ + time_list = [] + + if skip_iter >= count: + logging.error("skip_iter:[{}] must >= count:[{}]".format(skip_iter, count)) + return None + + def decorator(func): + + @functools.wraps(func) + def wrapper(idx, *args, **kw): + timer = EasyTimer(cuda=cuda) + with timer: + func(*args, **kw) + + if idx < skip_iter: + return + + time_list.append(timer.value * TIME_UNIT[unit]) + if idx == count - 1: + avg_tt = sum(time_list) / len(time_list) + logging.info( + "\"{}\": repeat[{}], avg_time[{:.4f}]{},".format( + func.__name__, len(time_list), avg_tt, TIME_NAME[unit] + ) + ) + + _store_timer_result(func.__name__, avg_tt) + time_list.clear() + + return wrapper + + return decorator + + +def dtype_2_byte(dtype: torch.dtype) -> int: + return TYPE_MAP[dtype] + + +def tensor_size_beauty_print(length: int, dtype: torch.dtype) -> tuple: + return byte_beauty_print(length * dtype_2_byte(dtype)) + + +def byte_beauty_print(nbytes: int) -> tuple: + """ + Overview: + Output the bytes in a human-readable format. + Arguments: + - nbytes (int): number of bytes. 
+ + Returns: + tuple: tuple of formatted bytes and units. + """ + unit_dict = [("GB", 1024 * 1024 * 1024), ("MB", 1024 * 1024), ("KB", 1024), ("B", 1)] + + for item in unit_dict: + if nbytes // item[1] > 0: + return nbytes / item[1], item[0] + + return nbytes, "B" diff --git a/ding/utils/lock_helper.py b/ding/utils/lock_helper.py index cb4a9c13b5..02c31c2191 100644 --- a/ding/utils/lock_helper.py +++ b/ding/utils/lock_helper.py @@ -2,6 +2,7 @@ import multiprocessing import threading import platform +import functools from enum import Enum, unique from readerwriterlock import rwlock @@ -12,6 +13,19 @@ fcntl = None +class DummyLock: + """ + DummyLock can be used in codes where locks are not required. + Reduce unnecessary code. + """ + + def acquire(self): + pass + + def release(self): + pass + + @unique class LockContextType(Enum): """ @@ -19,11 +33,15 @@ class LockContextType(Enum): """ THREAD_LOCK = 1 PROCESS_LOCK = 2 + DUMMY_LOCK = 3 + CONDITION_LOCK = 4 _LOCK_TYPE_MAPPING = { LockContextType.THREAD_LOCK: threading.Lock, LockContextType.PROCESS_LOCK: multiprocessing.Lock, + LockContextType.DUMMY_LOCK: DummyLock, + LockContextType.CONDITION_LOCK: threading.Condition } @@ -118,3 +136,24 @@ def get_file_lock(name: str, op: str) -> None: except Exception as e: pass return FcntlContext(lock_name) + + +def synchronized(func): + """ + Overview: + thread lock decorator. + Arguments: + - func ([type]): A function that needs to be protected by a lock. + """ + func.__lock__ = threading.Lock() + + def decorator(func): + + @functools.wraps(func) + def wrapper(*args, **kwargs): + with func.__lock__: + return func(*args, **kwargs) + + return wrapper + + return decorator diff --git a/dizoo/atari/example/atari_dqn_dist_ddp.py b/dizoo/atari/example/atari_dqn_dist_ddp.py index 6b615abb21..85b498d455 100644 --- a/dizoo/atari/example/atari_dqn_dist_ddp.py +++ b/dizoo/atari/example/atari_dqn_dist_ddp.py @@ -14,7 +14,6 @@ from dizoo.atari.envs.atari_env import AtariEnv from dizoo.atari.config.serial.pong.pong_dqn_config import main_config, create_config - logging.getLogger().setLevel(logging.INFO) main_config.exp_name = 'pong_dqn_seed0_ditask_dist_ddp' diff --git a/dizoo/atari/example/atari_dqn_dist_rdma.py b/dizoo/atari/example/atari_dqn_dist_rdma.py index 71fb1d64a1..e852108bda 100644 --- a/dizoo/atari/example/atari_dqn_dist_rdma.py +++ b/dizoo/atari/example/atari_dqn_dist_rdma.py @@ -8,15 +8,17 @@ from ding.framework import task, ding_init from ding.framework.context import OnlineRLContext from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \ - eps_greedy_handler, CkptSaver, context_exchanger, model_exchanger, termination_checker, nstep_reward_enhancer, \ - online_logger + eps_greedy_handler, CkptSaver, ContextExchanger, ModelExchanger, nstep_reward_enhancer, termination_checker from ding.utils import set_pkg_seed from dizoo.atari.envs.atari_env import AtariEnv from dizoo.atari.config.serial.pong.pong_dqn_config import main_config, create_config +from ding.utils import EasyTimer +import os +import time def main(): - logging.getLogger().setLevel(logging.INFO) + logger = logging.getLogger().setLevel(logging.DEBUG) main_config.exp_name = 'pong_dqn_seed0_dist_rdma' cfg = compile_config(main_config, create_cfg=create_config, auto=True) ding_init(cfg) @@ -26,46 +28,53 @@ def main(): set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) model = DQN(**cfg.policy.model) - policy = DQNPolicy(cfg.policy, model=model) + + # Consider the case with multiple processes + if 
task.router.is_active: + # You can use labels to distinguish between workers with different roles, + # here we use node_id to distinguish. + if task.router.node_id == 0: + task.add_role(task.role.LEARNER) + else: + task.add_role(task.role.COLLECTOR) + + logging.debug("label {}".format(task.router.labels)) + logging.debug("task role {}".format(task._roles)) if 'learner' in task.router.labels: + policy = DQNPolicy(cfg.policy, model=model) logging.info("Learner running on node {}".format(task.router.node_id)) buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size) - task.use( - context_exchanger( - send_keys=["train_iter"], - recv_keys=["trajectories", "episodes", "env_step", "env_episode"], - skip_n_iter=0 - ) - ) - task.use(model_exchanger(model, is_learner=True)) + task.use(ContextExchanger(skip_n_iter=0)) + task.use(ModelExchanger(model)) task.use(nstep_reward_enhancer(cfg)) task.use(data_pusher(cfg, buffer_)) task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) task.use(CkptSaver(cfg, policy, train_freq=1000)) elif 'collector' in task.router.labels: + policy = DQNPolicy(cfg.policy, model=model) logging.info("Collector running on node {}".format(task.router.node_id)) collector_cfg = deepcopy(cfg.env) collector_cfg.is_train = True + logging.info(cfg.env.manager) + logging.info(type(cfg.env.manager)) + # task.router.judge_use_cuda_shm(cfg) + logging.debug("cuda_shared_memory {}".format(cfg.env.manager.cuda_shared_memory)) collector_env = SubprocessEnvManagerV2( env_fn=[lambda: AtariEnv(collector_cfg) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager ) - task.use( - context_exchanger( - send_keys=["trajectories", "episodes", "env_step", "env_episode"], - recv_keys=["train_iter"], - skip_n_iter=1 - ) - ) - task.use(model_exchanger(model, is_learner=False)) + task.use(ContextExchanger(skip_n_iter=1)) + task.use(ModelExchanger(model)) task.use(eps_greedy_handler(cfg)) task.use(StepCollector(cfg, policy.collect_mode, collector_env)) task.use(termination_checker(max_env_step=int(1e7))) else: raise KeyError("invalid router labels: {}".format(task.router.labels)) - task.run() + start_time = task.run(max_step=100) + end_time = time.time() + logging.debug("atari iter 99 use {:.4f} s,".format(end_time - start_time)) if __name__ == "__main__": diff --git a/pytest.ini b/pytest.ini index efdeaba023..25c1e374a9 100644 --- a/pytest.ini +++ b/pytest.ini @@ -10,5 +10,7 @@ markers = envpooltest other tmp + multiprocesstest + mqbenchmark norecursedirs = ding/hpc_rl/tests
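
For reference, the snippet below is a minimal sketch (not part of the patch) of how the new torchrpc backend is expected to be driven end to end, assuming only the APIs introduced above: Parallel.runner(mq_type="torchrpc", ...), task.on/task.emit with bypass_eventloop, and the TORCHRPCMQ helpers get_mq(), wait_for_shutdown() and require_to_shutdown(). The rendezvous address, topic name, sleeps and shutdown wiring are illustrative choices for the sketch, not values taken from the patch.

# Illustrative sketch: run this script twice, once per node, e.g.
#   python torchrpc_demo.py 0    (receiver)
#   python torchrpc_demo.py 1    (sender)
# Assumes the torchrpc message queue added in this patch; the port and topic
# name are examples, and the sleeps are only crude synchronisation.
import sys
import time
import torch
from ding.framework import task
from ding.framework.parallel import Parallel


def main():
    with task.start():
        if task.router.node_id == 0:
            # Register the callback directly on the torchrpc backend; it is invoked
            # from the rpc backend thread, bypassing the local event loop.
            task.on("greeting", lambda t: print("got tensor, sum =", t.sum().item()), bypass_eventloop=True)
            # Block until a peer asks this message queue to shut down.
            task.router.get_mq().wait_for_shutdown()
        else:
            time.sleep(1)  # give the receiver time to subscribe (sketch-level sync)
            task.emit("greeting", torch.ones(8), bypass_eventloop=True)
            time.sleep(1)
            # Ask the remote torchrpc message queue to stop.
            task.router.get_mq().require_to_shutdown("Node_0")


if __name__ == "__main__":
    node_id = int(sys.argv[1])
    Parallel.runner(
        n_parallel_workers=1,
        mq_type="torchrpc",
        init_method="tcp://127.0.0.1:12399",    # example rendezvous address
        node_ids=[node_id],
        attach_to=[0] if node_id != 0 else [],  # node 1 attaches to node 0
    )(main)

In a real deployment each node is launched through ditask with its own node_ids/attach_to, and use_cuda/cuda_device_map can be added to enable GPU-to-GPU transfers as described in the commit message.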