diff --git a/.github/workflows/unit_test.yml b/.github/workflows/unit_test.yml index c7195d820b..c69e5fe0e6 100644 --- a/.github/workflows/unit_test.yml +++ b/.github/workflows/unit_test.yml @@ -11,12 +11,11 @@ jobs: if: "!contains(github.event.head_commit.message, 'ci skip')" strategy: matrix: - python-version: [3.7, 3.8, 3.9] - + python-version: ["3.7", "3.8", "3.9"] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v3 with: python-version: ${{ matrix.python-version }} - name: do_unittest @@ -41,12 +40,13 @@ jobs: if: "!contains(github.event.head_commit.message, 'ci skip')" strategy: matrix: - python-version: [3.7, 3.8, 3.9] - + python-version: ["3.7", "3.8", "3.9"] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v3 + env: + AGENT_TOOLSDIRECTORY: /opt/hostedtoolcache with: python-version: ${{ matrix.python-version }} - name: do_benchmark @@ -55,3 +55,70 @@ jobs: python -m pip install ".[test,k8s]" ./ding/scripts/install-k8s-tools.sh make benchmark + + test_multiprocess: + runs-on: self-hosted + if: "!contains(github.event.head_commit.message, 'ci skip')" + strategy: + matrix: + python-version: ["3.7", "3.8", "3.9"] + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + - name: do_multiprocesstest + timeout-minutes: 40 + run: | + python -m pip install box2d-py + python -m pip install . + python -m pip install ".[test,k8s]" + ./ding/scripts/install-k8s-tools.sh + make multiprocesstest + + test_cuda: + runs-on: self-hosted + if: "!contains(github.event.head_commit.message, 'ci skip')" + strategy: + matrix: + python-version: ["3.7", "3.8", "3.9"] + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + env: + AGENT_TOOLSDIRECTORY: /opt/hostedtoolcache + with: + python-version: ${{ matrix.python-version }} + - name: do_unittest + timeout-minutes: 40 + run: | + python -m pip install torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113 + python -m pip install box2d-py + python -m pip install . + python -m pip install ".[test,k8s]" + ./ding/scripts/install-k8s-tools.sh + make cudatest + + test_mq_benchmark: + runs-on: self-hosted + if: "!contains(github.event.head_commit.message, 'ci skip')" + strategy: + matrix: + python-version: ["3.7", "3.8", "3.9"] + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + env: + AGENT_TOOLSDIRECTORY: /opt/hostedtoolcache + with: + python-version: ${{ matrix.python-version }} + - name: do_mqbenchmark + run: | + python -m pip install torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113 + python -m pip install . 
+ python -m pip install ".[test,k8s]" + ./ding/scripts/install-k8s-tools.sh + make mqbenchmark \ No newline at end of file diff --git a/Makefile b/Makefile index 39810b7871..c6ead4d1ab 100644 --- a/Makefile +++ b/Makefile @@ -57,11 +57,25 @@ benchmark: --durations=0 \ -sv -m benchmark +multiprocesstest: + pytest ${TEST_DIR} \ + --cov-report=xml \ + --cov-report term-missing \ + --cov=${COV_DIR} \ + ${DURATIONS_COMMAND} \ + ${WORKERS_COMMAND} \ + -sv -m multiprocesstest + +mqbenchmark: + pytest ${TEST_DIR} \ + --durations=0 \ + -sv -m mqbenchmark + test: unittest # just for compatibility, can be changed later cpu_test: unittest algotest benchmark -all_test: unittest algotest cudatest benchmark +all_test: unittest algotest cudatest benchmark multiprocesstest format: yapf --in-place --recursive -p --verbose --style .style.yapf ${FORMAT_DIR} diff --git a/codecov.yml b/codecov.yml index 0779ada773..af3e5c97dd 100644 --- a/codecov.yml +++ b/codecov.yml @@ -6,3 +6,10 @@ coverage: target: auto threshold: 0.5% if_ci_failed: success #success, failure, error, ignore + +# fix me +# The unittests of the torchrpc module are tested by different runners and cannot be included +# in the test_unittest's coverage report. To keep CI happy, we don't count torchrpc related coverage. +ignore: + - /mnt/cache/wangguoteng/DI-engine/ding/framework/message_queue/torch_rpc.py + - /mnt/cache/wangguoteng/DI-engine/ding/framework/message_queue/perfs/* diff --git a/ding/compatibility.py b/ding/compatibility.py index dd6b1fd0da..94d37991e0 100644 --- a/ding/compatibility.py +++ b/ding/compatibility.py @@ -7,3 +7,7 @@ def torch_ge_131(): def torch_ge_180(): return int("".join(list(filter(str.isdigit, torch.__version__)))) >= 180 + + +def torch_ge_1121(): + return int("".join(list(filter(str.isdigit, torch.__version__)))) >= 1121 diff --git a/ding/data/shm_buffer.py b/ding/data/shm_buffer.py index b76f5d56e9..875a7210c7 100644 --- a/ding/data/shm_buffer.py +++ b/ding/data/shm_buffer.py @@ -3,6 +3,10 @@ import ctypes import numpy as np import torch +import torch.multiprocessing as mp +from functools import reduce +from ditk import logging +from abc import abstractmethod _NTYPE_TO_CTYPE = { np.bool_: ctypes.c_bool, @@ -18,8 +22,37 @@ np.float64: ctypes.c_double, } +# uint16, uint32, uint32 +_NTYPE_TO_TTYPE = { + np.bool_: torch.bool, + np.uint8: torch.uint8, + # np.uint16: torch.int16, + # np.uint32: torch.int32, + # np.uint64: torch.int64, + np.int8: torch.uint8, + np.int16: torch.int16, + np.int32: torch.int32, + np.int64: torch.int64, + np.float32: torch.float32, + np.float64: torch.float64, +} + +_NOT_SUPPORT_NTYPE = {np.uint16: torch.int16, np.uint32: torch.int32, np.uint64: torch.int64} +_CONVERSION_TYPE = {np.uint16: np.int16, np.uint32: np.int32, np.uint64: np.int64} + + +class ShmBufferBase: + + @abstractmethod + def fill(self, src_arr: Union[np.ndarray, torch.Tensor]) -> None: + raise NotImplementedError -class ShmBuffer(): + @abstractmethod + def get(self) -> Union[np.ndarray, torch.Tensor]: + raise NotImplementedError + + +class ShmBuffer(ShmBufferBase): """ Overview: Shared memory buffer to store numpy array. @@ -78,6 +111,94 @@ def get(self) -> np.ndarray: return data +class ShmBufferCuda(ShmBufferBase): + + def __init__( + self, + dtype: Union[torch.dtype, np.dtype], + shape: Tuple[int], + ctype: Optional[type] = None, + copy_on_get: bool = True, + device: Optional[torch.device] = torch.device('cuda:0') + ) -> None: + """ + Overview: + Use torch.multiprocessing for shared tensor or ndaray between processes. 
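+ The buffer is a single tensor allocated once on the target CUDA device: producers write into it
+ with fill() and consumers read it back with get(), relying on torch.multiprocessing's CUDA IPC
+ support to share the underlying storage across processes.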
+ Arguments: + - dtype (Union[torch.dtype, np.dtype]): dtype of torch.tensor or numpy.ndarray. + - shape (Tuple[int]): Shape of torch.tensor or numpy.ndarray. + - ctype (type): Origin class type, e.g. np.ndarray, torch.Tensor. + - copy_on_get (bool, optional): Can be set to False only if the shared object + is a tenor, otherwise True. + - device (Optional[torch.device], optional): The GPU device where cuda-shared-tensor + is located, the default is cuda:0. + + Raises: + RuntimeError: Unsupported share type by ShmBufferCuda. + """ + if isinstance(dtype, np.dtype): # it is type of gym.spaces.dtype + self.ctype = np.ndarray + dtype = dtype.type + if dtype in _NOT_SUPPORT_NTYPE.keys(): + logging.warning( + "Torch tensor unsupport numpy type {}, attempt to do a type conversion, which may lose precision.". + format(dtype) + ) + ttype = _NOT_SUPPORT_NTYPE[dtype] + self.dtype = _CONVERSION_TYPE[dtype] + else: + ttype = _NTYPE_TO_TTYPE[dtype] + self.dtype = dtype + elif isinstance(dtype, torch.dtype): + self.ctype = torch.Tensor + ttype = dtype + else: + raise RuntimeError("The dtype parameter only supports torch.dtype and np.dtype") + + self.copy_on_get = copy_on_get + self.shape = shape + self.device = device + # We don't want the buffer to be involved in the computational graph + with torch.no_grad(): + self.buffer = torch.zeros(reduce(lambda x, y: x * y, shape), dtype=ttype, device=self.device) + + def fill(self, src_arr: Union[np.ndarray, torch.Tensor]) -> None: + if self.ctype is np.ndarray: + if src_arr.dtype.type != self.dtype: + logging.warning( + "Torch tensor unsupport numpy type {}, attempt to do a type conversion, which may lose precision.". + format(self.dtype) + ) + src_arr = src_arr.astype(self.dtype) + tensor = torch.from_numpy(src_arr) + elif self.ctype is torch.Tensor: + tensor = src_arr + else: + raise RuntimeError("Unsopport CUDA-shared-tensor input type:\"{}\"".format(type(src_arr))) + + # If the GPU-a and GPU-b are connected using nvlink, the copy is very fast. + with torch.no_grad(): + self.buffer.copy_(tensor.view(tensor.numel())) + + def get(self) -> Union[np.ndarray, torch.Tensor]: + with torch.no_grad(): + if self.ctype is np.ndarray: + # Because ShmBufferCuda use CUDA memory exchanging data between processes. + # So copy_on_get is necessary for numpy arrays. + re = self.buffer.cpu() + re = re.detach().view(self.shape).numpy() + else: + if self.copy_on_get: + re = self.buffer.clone().detach().view(self.shape) + else: + re = self.buffer.view(self.shape) + + return re + + def __del__(self): + del self.buffer + + class ShmBufferContainer(object): """ Overview: @@ -88,7 +209,8 @@ def __init__( self, dtype: Union[Dict[Any, type], type, np.dtype], shape: Union[Dict[Any, tuple], tuple], - copy_on_get: bool = True + copy_on_get: bool = True, + is_cuda_buffer: bool = False ) -> None: """ Overview: @@ -98,11 +220,15 @@ def __init__( - shape (:obj:`Union[Dict[Any, tuple], tuple]`): If `Dict[Any, tuple]`, use a dict to manage \ multiple buffers; If `tuple`, use single buffer. - copy_on_get (:obj:`bool`): Whether to copy data when calling get method. + - is_cuda_buffer (:obj:`bool`): Whether to use pytorch CUDA shared tensor as the implementation of shm. 
""" if isinstance(shape, dict): - self._data = {k: ShmBufferContainer(dtype[k], v, copy_on_get) for k, v in shape.items()} + self._data = {k: ShmBufferContainer(dtype[k], v, copy_on_get, is_cuda_buffer) for k, v in shape.items()} elif isinstance(shape, (tuple, list)): - self._data = ShmBuffer(dtype, shape, copy_on_get) + if not is_cuda_buffer: + self._data = ShmBuffer(dtype, shape, copy_on_get) + else: + self._data = ShmBufferCuda(dtype, shape, copy_on_get) else: raise RuntimeError("not support shape: {}".format(shape)) self._shape = shape diff --git a/ding/data/tests/test_shm_buffer.py b/ding/data/tests/test_shm_buffer.py index 04334b4799..6316e40b66 100644 --- a/ding/data/tests/test_shm_buffer.py +++ b/ding/data/tests/test_shm_buffer.py @@ -1,20 +1,90 @@ +from ding.data.shm_buffer import ShmBuffer, ShmBufferCuda +from ding.compatibility import torch_ge_1121 + import pytest import numpy as np import timeit -from ding.data.shm_buffer import ShmBuffer -import multiprocessing as mp +import torch +import time -def subprocess(shm_buf): +def subprocess_np_shm(shm_buf): data = np.random.rand(1024, 1024).astype(np.float32) res = timeit.repeat(lambda: shm_buf.fill(data), repeat=5, number=1000) print("Mean: {:.4f}s, STD: {:.4f}s, Mean each call: {:.4f}ms".format(np.mean(res), np.std(res), np.mean(res))) +def subprocess_cuda_shared_tensor(shm_buf_np, shm_buf_torch, event_run): + event_run.wait() + rtensor = shm_buf_torch.get() + assert isinstance(rtensor, torch.Tensor) + assert rtensor.device == torch.device('cuda:0') + assert rtensor.dtype == torch.float32 + assert rtensor.sum().item() == 1024 * 1024 + + rarray = shm_buf_np.get() + assert isinstance(rarray, np.ndarray) + assert rarray.dtype == np.dtype(np.float32) + assert rarray.dtype == np.dtype(np.float32) + + res = timeit.repeat(lambda shm_buf_torch=shm_buf_torch: shm_buf_torch.get(), repeat=5, number=1000) + print("CUDA-shared-tensor (torch) Get: mean: {:.4f}s, STD: {:.4f}s".format(np.mean(res), np.std(res))) + res = timeit.repeat(lambda shm_buf_np=shm_buf_np: shm_buf_np.get(), repeat=5, number=1000) + print("CUDA-shared-tensor (numpy) Get: mean: {:.4f}s, STD: {:.4f}s".format(np.mean(res), np.std(res))) + + del shm_buf_np + del shm_buf_torch + + @pytest.mark.benchmark def test_shm_buffer(): + import multiprocessing as mp data = np.random.rand(1024, 1024).astype(np.float32) shm_buf = ShmBuffer(data.dtype, data.shape, copy_on_get=False) - proc = mp.Process(target=subprocess, args=[shm_buf]) + proc = mp.Process(target=subprocess_np_shm, args=[shm_buf]) proc.start() proc.join() + + +@pytest.mark.benchmark +@pytest.mark.cudatest +@pytest.mark.multiprocesstest +def test_cuda_shm(): + if torch.cuda.is_available() and torch.cuda.device_count() >= 2: + import torch.multiprocessing as mp + ctx = mp.get_context('spawn') + + event_run = ctx.Event() + shm_buf_np = ShmBufferCuda(np.dtype(np.float32), shape=(1024, 1024), copy_on_get=True) + shm_buf_torch = ShmBufferCuda(torch.float32, shape=(1024, 1024), copy_on_get=True) + proc = ctx.Process(target=subprocess_cuda_shared_tensor, args=[shm_buf_np, shm_buf_torch, event_run]) + proc.start() + + ltensor = torch.ones((1024, 1024), dtype=torch.float32).cuda(0 if torch.cuda.device_count() == 1 else 1) + larray = np.random.rand(1024, 1024).astype(np.float32) + shm_buf_torch.fill(ltensor) + shm_buf_np.fill(larray) + + res = timeit.repeat(lambda shm_buf_torch=shm_buf_torch: shm_buf_torch.fill(ltensor), repeat=5, number=1000) + print("CUDA-shared-tensor (torch) Fill: mean: {:.4f}s, STD: {:.4f}s".format(np.mean(res), 
np.std(res))) + res = timeit.repeat(lambda shm_buf_np=shm_buf_np: shm_buf_np.fill(larray), repeat=5, number=1000) + print("CUDA-shared-tensor (numpy) Fill: mean: {:.4f}s, STD: {:.4f}s".format(np.mean(res), np.std(res))) + + rtensor = shm_buf_torch.get() + assert isinstance(rtensor, torch.Tensor) + assert rtensor.device == torch.device('cuda:0') + assert rtensor.shape == ltensor.shape + assert rtensor.dtype == ltensor.dtype + + rarray = shm_buf_np.get() + assert isinstance(rarray, np.ndarray) + assert larray.shape == rarray.shape + assert larray.dtype == rarray.dtype + + event_run.set() + + # Keep producer process running until all consumers exits. + proc.join() + + del shm_buf_np + del shm_buf_torch diff --git a/ding/entry/cli_ditask.py b/ding/entry/cli_ditask.py index 443fe1a6b6..29af0af2ad 100644 --- a/ding/entry/cli_ditask.py +++ b/ding/entry/cli_ditask.py @@ -57,12 +57,36 @@ def print_version(ctx: Context, param: Option, value: bool) -> None: ) @click.option("--platform-spec", type=str, help="Platform specific configure.") @click.option("--platform", type=str, help="Platform type: slurm, k8s.") -@click.option("--mq-type", type=str, default="nng", help="Class type of message queue, i.e. nng, redis.") +@click.option( + "--mq-type", + type=str, + default="nng", + help="Class type of message queue, i.e. nng, redis, torchrpc:cuda, torchrpc:cpu." +) @click.option("--redis-host", type=str, help="Redis host.") @click.option("--redis-port", type=int, help="Redis port.") @click.option("-m", "--main", type=str, help="Main function of entry module.") @click.option("--startup-interval", type=int, default=1, help="Start up interval between each task.") @click.option("--local_rank", type=int, default=0, help="Compatibility with PyTorch DDP") +@click.option( + "--init-method", + type=str, + help="[Torchrpc]: Init method both for init_rpc and init_process_group, please refer to pytorch init_method" +) +@click.option( + "--local-cuda-devices", + type=str, + help='''[Torchrpc]: [Optional] Specifies the device ranks of the GPUs used by the local process, a comma-separated + list of integers.''' +) +@click.option( + "--cuda-device-map", + type=str, + help='''[Torchrpc]: [Optional] Specify device mapping. + Ref: + Format: --cuda-device-map=__,[...] 
+ ''' +) def cli_ditask(*args, **kwargs): return _cli_ditask(*args, **kwargs) @@ -107,9 +131,12 @@ def _cli_ditask( redis_host: str, redis_port: int, startup_interval: int, + init_method: str = None, local_rank: int = 0, platform: str = None, platform_spec: str = None, + local_cuda_devices: str = None, + cuda_device_map: str = None ): # Parse entry point all_args = locals() @@ -145,6 +172,18 @@ def _cli_ditask( if node_ids and not isinstance(node_ids, int): node_ids = node_ids.split(",") node_ids = list(map(lambda i: int(i), node_ids)) + use_cuda = False + if mq_type == "torchrpc:cuda" or mq_type == "torchrpc:cpu": + mq_type, use_cuda = mq_type.split(":") + if use_cuda == "cuda": + use_cuda = True + if local_cuda_devices: + local_cuda_devices = local_cuda_devices.split(",") + local_cuda_devices = list(map(lambda s: s.strip(), local_cuda_devices)) + if cuda_device_map: + cuda_device_map = cuda_device_map.split(",") + cuda_device_map = list(map(lambda s: s.strip(), cuda_device_map)) + Parallel.runner( n_parallel_workers=parallel_workers, ports=ports, @@ -157,5 +196,9 @@ def _cli_ditask( mq_type=mq_type, redis_host=redis_host, redis_port=redis_port, - startup_interval=startup_interval + init_method=init_method, + startup_interval=startup_interval, + use_cuda=use_cuda, + local_cuda_devices=local_cuda_devices, + cuda_device_map=cuda_device_map )(main_func) diff --git a/ding/envs/env_manager/subprocess_env_manager.py b/ding/envs/env_manager/subprocess_env_manager.py index 1648981f03..fdcc61de17 100644 --- a/ding/envs/env_manager/subprocess_env_manager.py +++ b/ding/envs/env_manager/subprocess_env_manager.py @@ -1,5 +1,6 @@ from typing import Any, Union, List, Tuple, Dict, Callable, Optional from multiprocessing import connection, get_context +# from torch.multiprocessing import connection, get_context from collections import namedtuple from ditk import logging import platform @@ -12,6 +13,7 @@ import cloudpickle import numpy as np import treetensor.numpy as tnp +import treetensor.torch as ttorch from easydict import EasyDict from types import MethodType from ding.data import ShmBufferContainer, ShmBuffer @@ -70,6 +72,7 @@ class AsyncSubprocessEnvManager(BaseEnvManager): retry_waiting_time=0.1, # subprocess specified args shared_memory=True, + cuda_shared_memory=False, copy_on_get=True, context='spawn' if platform.system().lower() == 'windows' else 'fork', wait_num=2, @@ -97,6 +100,7 @@ def __init__( """ super().__init__(env_fn, cfg) self._shared_memory = self._cfg.shared_memory + self._cuda_shared_memory = self._cfg.cuda_shared_memory if self._shared_memory else False self._copy_on_get = self._cfg.copy_on_get self._context = self._cfg.context self._wait_num = self._cfg.wait_num @@ -134,7 +138,9 @@ def _create_state(self) -> None: shape = obs_space.shape dtype = obs_space.dtype self._obs_buffers = { - env_id: ShmBufferContainer(dtype, shape, copy_on_get=self._copy_on_get) + env_id: ShmBufferContainer( + dtype, shape, copy_on_get=self._copy_on_get, is_cuda_buffer=self._cuda_shared_memory + ) for env_id in range(self.env_num) } else: @@ -148,7 +154,11 @@ def _create_state(self) -> None: def _create_env_subprocess(self, env_id): # start a new one - ctx = get_context(self._context) + if self._cuda_shared_memory: + import torch.multiprocessing as mp + ctx = mp.get_context('spawn') + else: + ctx = get_context(self._context) self._pipe_parents[env_id], self._pipe_children[env_id] = ctx.Pipe() self._subprocesses[env_id] = ctx.Process( # target=self.worker_fn, @@ -705,6 +715,7 @@ class 
SyncSubprocessEnvManager(AsyncSubprocessEnvManager): retry_waiting_time=0.1, # subprocess specified args shared_memory=True, + cuda_shared_memory=False, copy_on_get=True, context='spawn' if platform.system().lower() == 'windows' else 'fork', wait_num=float("inf"), # inf mean all the environments @@ -802,7 +813,7 @@ class SubprocessEnvManagerV2(SyncSubprocessEnvManager): """ @property - def ready_obs(self) -> tnp.array: + def ready_obs(self) -> Union[tnp.array, torch.Tensor]: """ Overview: Get the ready (next) observation in ``tnp.array`` type, which is uniform for both async/sync scenarios. @@ -822,7 +833,10 @@ def ready_obs(self) -> tnp.array: ) time.sleep(0.001) sleep_count += 1 - return tnp.stack([tnp.array(self._ready_obs[i]) for i in self.ready_env]) + if not self._cuda_shared_memory: + return tnp.stack([tnp.array(self._ready_obs[i]) for i in self.ready_env]) + else: + return ttorch.stack([ttorch.tensor(self._ready_obs[i]) for i in self.ready_env]) def step(self, actions: List[tnp.ndarray]) -> List[tnp.ndarray]: """ @@ -846,5 +860,16 @@ def step(self, actions: List[tnp.ndarray]) -> List[tnp.ndarray]: # in order to call them as attribute (e.g. timestep.xxx), such as ``TimeLimit.truncated`` in cartpole info info = make_key_as_identifier(info) info = remove_illegal_item(info) - new_data.append(tnp.array({'obs': obs, 'reward': reward, 'done': done, 'info': info, 'env_id': env_id})) + if not self._cuda_shared_memory: + new_data.append(tnp.array({'obs': obs, 'reward': reward, 'done': done, 'info': info, 'env_id': env_id})) + else: + new_data.append( + ttorch.tensor({ + 'obs': obs, + 'reward': reward, + 'done': done, + 'info': info, + 'env_id': env_id + }) + ) return new_data diff --git a/ding/framework/__init__.py b/ding/framework/__init__.py index 72c23d0475..fd489588e7 100644 --- a/ding/framework/__init__.py +++ b/ding/framework/__init__.py @@ -1,6 +1,6 @@ from .context import Context, OnlineRLContext, OfflineRLContext -from .task import Task, task, VoidMiddleware -from .parallel import Parallel +from .task import Task, task, VoidMiddleware, enable_async +from .parallel import Parallel, MQType from .event_loop import EventLoop from .supervisor import Supervisor from easydict import EasyDict diff --git a/ding/framework/message_queue/README.md b/ding/framework/message_queue/README.md new file mode 100644 index 0000000000..3267dbecfd --- /dev/null +++ b/ding/framework/message_queue/README.md @@ -0,0 +1,13 @@ +# Notes on using torchrpc + +## Problems you may encounter + +Message queue of Torchrpc uses [tensorpipe](https://github.com/pytorch/tensorpipe) as a communication backend, a high-performance modular tensor-p2p communication library. However, several tensorpipe defects have been found in the test, which may make it difficult for you to use it. + +### 1. container environment + +Tensorpipe is not container aware. Processes can find themselves on the same physical machine through `/proc/sys/kernel/random/boot_id` ,but because in separated pod/container, they cannot use means of communication such as CUDA ipc. When tensorpipe finds that these communication methods cannot be used, it will report an error and exit. + +### 2. RDMA and fork subprocess + +Tensorpipe does not consider the case of calling [fork(2)](https://man7.org/linux/man-pages/man2/fork.2.html) when using RDMA. If the corresponding initialization measures are not performed when using RDMA, using fork will cause serious problems, refer to [here](https://www.rdmamojo.com/2012/05/24/ibv_fork_init/). 
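The underlying issue is that RDMA-registered memory is pinned and DMA-mapped in the parent process; after a fork, copy-on-write can leave the NIC reading from or writing to physical pages that no longer back the parent's buffers, silently corrupting data unless ibverbs is told to prepare for fork.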
Therefore, if you start ditask in the IB/RoCE network environment, please specify the environment variables `IBV_FORK_SAFE=1` and `RDMAV_FORK_SAFE=1` , so that ibverbs will automatically initialize fork support. \ No newline at end of file diff --git a/ding/framework/message_queue/__init__.py b/ding/framework/message_queue/__init__.py index 7cbbbcd93c..3cedbe11d7 100644 --- a/ding/framework/message_queue/__init__.py +++ b/ding/framework/message_queue/__init__.py @@ -1,3 +1,4 @@ from .mq import MQ from .redis import RedisMQ from .nng import NNGMQ +from .torch_rpc import TORCHRPCMQ, DeviceMap diff --git a/ding/framework/message_queue/mq.py b/ding/framework/message_queue/mq.py index 4386882020..37a6b61676 100644 --- a/ding/framework/message_queue/mq.py +++ b/ding/framework/message_queue/mq.py @@ -1,4 +1,4 @@ -from typing import Tuple +from typing import Tuple, Optional class MQ: @@ -31,12 +31,15 @@ def publish(self, topic: str, data: bytes) -> None: """ raise NotImplementedError - def subscribe(self, topic: str) -> None: + def subscribe(self, topic: str, fn: Optional[callable] = None, is_once: Optional[bool] = False) -> None: """ Overview: Subscribe to the topic. Arguments: - topic (:obj:`str`): Topic + - fn (:obj:`Optional[callable]`): The message handler, if the communication library + implements event_loop, it can bypass Parallel() and calling this function by itself. + - is_once (:obj:`bool`): Whether Topic will only be called once. """ raise NotImplementedError diff --git a/ding/framework/message_queue/nng.py b/ding/framework/message_queue/nng.py index 379601b0ed..5298fc0a55 100644 --- a/ding/framework/message_queue/nng.py +++ b/ding/framework/message_queue/nng.py @@ -39,7 +39,7 @@ def publish(self, topic: str, data: bytes) -> None: data = topic.encode() + data self._sock.send(data) - def subscribe(self, topic: str) -> None: + def subscribe(self, topic: str, fn: Optional[callable] = None, is_once: Optional[bool] = False) -> None: return def unsubscribe(self, topic: str) -> None: diff --git a/ding/framework/message_queue/perfs/perf_nng.py b/ding/framework/message_queue/perfs/perf_nng.py new file mode 100644 index 0000000000..d597518b54 --- /dev/null +++ b/ding/framework/message_queue/perfs/perf_nng.py @@ -0,0 +1,274 @@ +import pickle +import multiprocessing as mp +import argparse +import os +import time +import torch +import numpy as np +import click +import struct + +from time import sleep +from threading import Thread +from ding.framework.message_queue.nng import NNGMQ +from ditk import logging +from ding.framework.parallel import Parallel +from ding.utils.comm_perf_helper import byte_beauty_print, time_perf_avg, print_timer_result_csv +from ding.utils import EasyTimer, WatchDog + +logging.getLogger().setLevel(logging.INFO) +REPEAT = 10 +LENGTH = 5 +EXP_NUMS = 2 +UNIT_SIZE_LIST = [64, 1024, 64 * 1024, 512 * 1024, 2 * 1024 * 1024] + + +@click.command(context_settings=dict(help_option_names=['-h', '--help'])) +@click.option("--ports", type=str, default="50515") +@click.option("--attach-to", type=str, help="The addresses to connect to.") +@click.option("--address", type=str, help="The address to listen to (without port).") +@click.option("--labels", type=str, help="Labels.") +@click.option("--node-ids", type=str, help="Candidate node ids.") +def handle_args(*args, **kwargs): + return nng_perf_main(*args, **kwargs) + + +def pack_time(data, value): + if value: + return struct.pack('d', value) + "::".encode() + data + else: + return struct.pack('d', value) + + +def unpack_time(value): + 
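+ # Recover the 8-byte double timestamp that pack_time() prepended via struct.pack('d', ...).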
return struct.unpack('=d', value)[0] + + +def nng_dist_main(labels, node_id, listen_to, attach_to, *arg, **kwargs) -> None: + """ + Overview: + Since nng message reception may be out of order, and nng + does not have a handshake, the sender may start + sending messages and timing before the receiver is ready. + So this function does the corresponding work. + """ + mq = NNGMQ(listen_to=listen_to, attach_to=attach_to) + mq.listen() + label = labels.pop() + rank = 0 + future_dict = dict() + start_tag = [] + finish_tag = [] + + def send_t(topic, data=None): + try: + if not data: + data = [0, 0] + data = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL) + mq.publish(topic, data) + logging.debug("send topic {}".format(topic)) + except Exception as e: + logging.error("send error at rank:{} label:\"{}\", topic:\"{}\", error: {}".format(rank, label, topic, e)) + + def recv_loop(): + while True: + topic, data = mq.recv() + if topic == "z": + # perf_nng_detail recv callback. + timestamps, data = data.split(b"::", maxsplit=1) + h2d_timer = EasyTimer(cuda=True) + pickle_timer = EasyTimer(cuda=False) + + with pickle_timer: + data = pickle.loads(data) + data, idx = data[0], data[1] + + with h2d_timer: + data = data.cuda(0) + + data = pickle.dumps([timestamps, idx], protocol=pickle.HIGHEST_PROTOCOL) + time_res = pack_time(data, pickle_timer.value) + time_res = pack_time(time_res, h2d_timer.value) + + mq.publish("k", time_res) + continue + elif topic == "k": + # perf_nng_detail send callback. + h2d_time, pickle_time, data = data.split(b"::", maxsplit=2) + data = pickle.loads(data) + timestamps, idx = data[0], data[1] + future_dict['perf_finsh'] = (unpack_time(h2d_time), unpack_time(pickle_time), unpack_time(timestamps)) + future_dict[idx] = 1 + continue + else: + # Callback functions for other tests. 
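+ # Topic protocol used below: "t" = single tensor, "d" = dict of tensors (both acknowledged
+ # via "a"), "a" = ack from the peer, "s" = start/handshake, "f" = finish signal.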
+ data = pickle.loads(data) + data, idx = data[0], data[1] + if topic == "t": + assert isinstance(data, torch.Tensor) + data = data.cuda(0) + torch.cuda.synchronize(0) + pass + elif topic == "d": + assert isinstance(data, dict) + for k, v in data.items(): + data[k] = v.cuda(0) + torch.cuda.synchronize(0) + elif topic == "a": + if idx not in future_dict.keys(): + raise RuntimeError("Unkown idx") + future_dict[idx] = 1 + continue + elif topic == "s": + if label == 'collector': + send_t("s") + elif label == 'learner': + start_tag.append(1) + continue + elif topic == "f": + finish_tag.append(1) + return + else: + raise RuntimeError("Unkown topic") + + send_t("a", ["", idx]) + + def irendezvous(): + timeout_killer = WatchDog(3) + timeout_killer.start() + send_t("s") + while len(start_tag) == 0: + time.sleep(0.05) + timeout_killer.stop() + + listen_thread = Thread(target=recv_loop, name="recv_loop", daemon=True) + listen_thread.start() + + if label == 'learner': + while True: + try: + irendezvous() + except Exception as e: + logging.warning("timeout for irendezvous") + else: + break + + if label == 'learner': + + for size in UNIT_SIZE_LIST: + unit_size = size * LENGTH + gpu_data = torch.ones(unit_size).cuda(rank) + time_list = [list() for i in range(EXP_NUMS)] + size_lists = [[size] for i in range(LENGTH)] + send_func_list = [] + logging.info("Data size: {:.2f} {}".format(*byte_beauty_print(unit_size * 4))) + tensor_dict = dict() + for j, size_list in enumerate(size_lists): + tensor_dict[str(j)] = torch.ones(size_list).cuda(rank) + + @time_perf_avg(1, REPEAT, cuda=True) + def nng_tensor_sender_1(idx): + future_dict[idx] = 0 + send_t("t", [gpu_data.cpu(), idx]) + while future_dict[idx] == 0: + time.sleep(0.03) + + @time_perf_avg(1, REPEAT, cuda=True) + def nng_tensor_sender_2(idx): + tmp_dict = dict() + future_dict[idx] = 0 + for key, value in tensor_dict.items(): + tmp_dict[key] = value.cpu() + send_t("d", [tmp_dict, idx]) + while future_dict[idx] == 0: + time.sleep(0.03) + + def perf_nng_detail(idx): + future_dict[idx] = 0 + h2d_timer = EasyTimer(cuda=True) + pickle_timer = EasyTimer(cuda=False) + + with h2d_timer: + data = gpu_data.cpu() + + with pickle_timer: + data = pickle.dumps([data, idx], protocol=pickle.HIGHEST_PROTOCOL) + + data = pack_time(data, time.time()) + mq.publish("z", data) + + while future_dict[idx] == 0: + time.sleep(0.03) + + peer_h2d_time, peer_pickle_time, timestamps = future_dict['perf_finsh'] + total_time = time.time() - timestamps + # Serialization time + pickle_time = peer_pickle_time + pickle_timer.value + # H2D/D2H time + pcie_time = peer_h2d_time + h2d_timer.value + # TCP I/O time + IO_time = total_time - pickle_time - pcie_time + logging.info( + "Detailed: total:[{:.4f}]ms, pickle:[{:.4f}]ms, H2D/D2H:[{:.4f}]ms, I/O:[{:.4f}]ms".format( + total_time, pickle_time, pcie_time, IO_time + ) + ) + # print("{:.4f}, {:.4f}, {:.4f}, {:.4f}".format(total_time, pickle_time, pcie_time, IO_time)) + + send_func_list.append(nng_tensor_sender_1) + send_func_list.append(nng_tensor_sender_2) + + for i in range(len(send_func_list)): + for j in range(REPEAT): + send_func_list[i](j, i + j) + + # Determine the time-consuming of each stage of nng. 
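+ # (Reported as total / pickle serialization / H2D-D2H copy / remaining TCP I/O time.)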
+ perf_nng_detail(0) + + # Do some proper cleanup to prevent cuda memory overflow + torch.cuda.empty_cache() + + if label == 'learner': + send_t("f") + finish_tag.append(1) + + while len(finish_tag) == 0: + time.sleep(0.1) + + print_timer_result_csv() + + +def nng_perf_main(ports: str, attach_to: str, address: str, labels: str, node_ids: str): + if not isinstance(ports, int): + ports = ports.split(",") + ports = list(map(lambda i: int(i), ports)) + ports = ports[0] if len(ports) == 1 else ports + if attach_to: + attach_to = attach_to.split(",") + attach_to = list(map(lambda s: s.strip(), attach_to)) + if labels: + labels = labels.split(",") + labels = set(map(lambda s: s.strip(), labels)) + if node_ids and not isinstance(node_ids, int): + node_ids = node_ids.split(",") + node_ids = list(map(lambda i: int(i), node_ids)) + + runner_params = Parallel._nng_args_parser( + n_parallel_workers=1, + ports=ports, + protocol="tcp", + attach_to=attach_to, + address=address, + labels=labels, + node_ids=node_ids, + ) + logging.debug(runner_params) + nng_dist_main(**runner_params[0]) + + +# Usages: +# CUDA_VISIBLE_DEVICES=0 python perf_nng.py --node-ids 0 --labels learner --ports 12345 --address 0.0.0.0 +# CUDA_VISIBLE_DEVICES=1 python perf_nng.py --node-ids 1 --labels collector --address 127.0.0.1 \ +# --ports 12355 --attach-to tcp://0.0.0.0:12345 +if __name__ == "__main__": + handle_args() diff --git a/ding/framework/message_queue/perfs/perf_shm.py b/ding/framework/message_queue/perfs/perf_shm.py new file mode 100644 index 0000000000..234f49213b --- /dev/null +++ b/ding/framework/message_queue/perfs/perf_shm.py @@ -0,0 +1,141 @@ +from typing import TYPE_CHECKING, Any, List, Union, Dict, Optional, Callable + +from ditk import logging +from ding.framework.supervisor import RecvPayload, SendPayload, Supervisor, ChildType +from ding.envs.env_manager.subprocess_env_manager import ShmBufferContainer, ShmBuffer +from ding.utils.comm_perf_helper import tensor_size_beauty_print, byte_beauty_print, \ + dtype_2_byte, TENSOR_SIZE_LIST, print_timer_result_csv + +import torch +import numpy as np +import time +import argparse + +LENGTH = 5 +REPEAT = 10 +UNIT_SIZE_LIST = [64, 1024, 64 * 1024, 512 * 1024, 2 * 1024 * 1024] +logging.getLogger().setLevel(logging.INFO) + + +def shm_callback(payload: RecvPayload, buffers: Any): + # Step4: shared memory -> np.array + np_tensor = buffers[payload.data["idx"]].get() + # Step5: np.array -> cpu tensor + tensor = torch.from_numpy(np_tensor) + # Step6: cpu tensor -> gpu tensor + tensor = tensor.cuda(0) + torch.cuda.synchronize(0) + + +def cuda_shm_callback(payload: RecvPayload, buffers: Any): + # Step2: gpu shared tensor -> gpu tensor + tensor = buffers[payload.data["idx"]].get() + assert tensor.device == torch.device('cuda:0') + # Step3: gpu tensor(cuda:0) -> gpu tensor(cuda:1) + tensor = tensor.to(1) + torch.cuda.synchronize(1) + assert tensor.device == torch.device('cuda:1') + + +class Recvier: + + def step(self, idx: int, __start_time): + return {"idx": idx, "start_time": __start_time} + + +class ShmSupervisor(Supervisor): + + def __init__(self, gpu_tensors, buffers, ctx, is_cuda_buffer): + super().__init__(type_=ChildType.PROCESS, mp_ctx=ctx) + self.gpu_tensors = gpu_tensors + self.buffers = buffers + self.time_list = [] + self._time_list = [] + self._is_cuda_buffer = is_cuda_buffer + if not is_cuda_buffer: + _shm_callback = shm_callback + else: + _shm_callback = cuda_shm_callback + self.register(Recvier, shm_buffer=self.buffers, shm_callback=_shm_callback) + 
super().start_link() + + def _send_recv_callback(self, payload: RecvPayload, remain_payloads: Optional[Dict[str, SendPayload]] = None): + idx = payload.data["idx"] + __start_time = payload.data["start_time"] + __end_time = time.time() + self.time_list.append(float(__end_time - __start_time) * 1000.0) + + def step(self): + # Do not use Queue to send large data, use shm. + for i, size in enumerate(UNIT_SIZE_LIST): + for j in range(REPEAT): + __start_time = time.time() + + if not self._is_cuda_buffer: + # Numpy shm buffer: + # Step1: gpu tensor -> cpu tensor + tensor = self.gpu_tensors[i].cpu() + # Step2: cpu tensor-> np.array + np_tensor = tensor.numpy() + # Step3: np.array -> shared memory + self.buffers[i].fill(np_tensor) + else: + # Cuda shared tensor + # Step1: gpu tensor -> gpu shared tensor + self.buffers[i].fill(self.gpu_tensors[i]) + + payload = SendPayload(proc_id=0, method="step", args=[i, __start_time]) + send_payloads = [payload] + + self.send(payload) + self.recv_all(send_payloads, ignore_err=True, callback=self._send_recv_callback) + + _avg_time = sum(self.time_list) / len(self.time_list) + self._time_list.append(_avg_time) + self.time_list.clear() + logging.info( + "Data size {:.2f} {} , repeat {}, avg RTT {:.4f} ms".format( + *byte_beauty_print(UNIT_SIZE_LIST[i] * 4 * LENGTH), REPEAT, _avg_time + ) + ) + + for t in self._time_list: + print("{:.4f},".format(t), end="") + print("") + + +def shm_perf_main(test_type: str): + gpu_tensors = list() + buffers = dict() + + if test_type == "shm": + import multiprocessing as mp + use_cuda_buffer = False + elif test_type == "cuda_ipc": + use_cuda_buffer = True + import torch.multiprocessing as mp + + ctx = mp.get_context('spawn') + + for i, size in enumerate(UNIT_SIZE_LIST): + unit_size = size * LENGTH + gpu_tensors.append(torch.ones(unit_size).cuda(0)) + if not use_cuda_buffer: + buffers[i] = ShmBufferContainer(np.float32, (unit_size, ), copy_on_get=True, is_cuda_buffer=False) + else: + buffers[i] = ShmBufferContainer(torch.float32, (unit_size, ), copy_on_get=True, is_cuda_buffer=True) + + sv = ShmSupervisor( + gpu_tensors=gpu_tensors, buffers=buffers, ctx=mp.get_context('spawn'), is_cuda_buffer=use_cuda_buffer + ) + sv.step() + del sv + + +# Usages: +# python perf_shm.py --test_type ["shm"|"cuda_ipc"] +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Test torch rpc') + parser.add_argument('--test_type', type=str) + args, _ = parser.parse_known_args() + shm_perf_main(args.test_type) diff --git a/ding/framework/message_queue/perfs/perf_torchrpc_nccl.py b/ding/framework/message_queue/perfs/perf_torchrpc_nccl.py new file mode 100644 index 0000000000..67fbb73e46 --- /dev/null +++ b/ding/framework/message_queue/perfs/perf_torchrpc_nccl.py @@ -0,0 +1,278 @@ +import time +import torch +import os +import argparse +import torch.distributed as dist +import treetensor.torch as ttorch + +from dataclasses import dataclass +from queue import Empty +from typing import TYPE_CHECKING, List, Dict, Union +from ditk import logging + +from ding.utils.data.structure.lifo_deque import LifoDeque +from ding.framework.message_queue.torch_rpc import DeviceMap, TORCHRPCMQ, RPCEvent +from ding.utils.comm_perf_helper import tensor_size_beauty_print, byte_beauty_print, \ + dtype_2_byte, DO_PERF, time_perf_avg, time_perf_once, print_timer_result_csv + +LENGTH = 5 +REPEAT = 2 +MAX_EXP_NUMS = 10 +UNIT_SIZE_LIST = [64, 1024, 64 * 1024, 512 * 1024, 2 * 1024 * 1024] + + +@dataclass +class SendInfo: + send_map_dict: Dict = None + sending_flag: int 
= 0 + + +# Global vars definition is here: +global mq +global global_send_info_dict +mq = None + +global_send_info_dict = dict() + + +def remote_mq_entrance(topic, *args, **kwargs): + global mq + mq.rpc_event_router(topic, *args, **kwargs) + + +def dict_tensor_send_by_key( + key: str, tensor_id: int, tensor: torch.Tensor, nums: int, send_id: int, use_cuda: bool +) -> None: + """ + Overview: + For data structures that use dict to store tensor, such as dict[key:tensor] or treetensor, + this function can be used. Each key is transmitted using one rpc, and the rpc transmission + of each key is asynchronous. + Arguments: + - key (str): Key in dict. + - tensor_id (int): The sending tensor ID during one dict/treetensor rpc transmission. + - tensor (torch.tensor): The tensor to be sent. + - nums (int): The total number of sent tensors. + - send_id (int): The ID of this dict/treetensor rpc transmission. + """ + global global_send_info_dict + send_info_dict = global_send_info_dict + send_info = None + + assert isinstance(key, str) + assert isinstance(tensor_id, int) + assert isinstance(tensor, torch.Tensor) + assert isinstance(nums, int) + assert isinstance(send_id, int) + assert isinstance(use_cuda, bool) + + if tensor_id == 0: + send_info = SendInfo() + send_info.send_map_dict = dict() + send_info_dict[send_id] = send_info + else: + while True: + if send_id in send_info_dict.keys(): + send_info = send_info_dict[send_id] + if send_info is not None: + break + + assert isinstance(send_info, SendInfo) + + if key in send_info.send_map_dict.keys(): + raise RuntimeError("Multiple state_dict's key \"{}\" received!".format(key)) + + send_info.send_map_dict[key] = tensor + + if tensor_id == nums - 1: + while len(send_info.send_map_dict) != nums: + time.sleep(0.01) + + send_info_dict.clear() + if use_cuda: + torch.cuda.synchronize(0) + return + + +def send_dummy(playload: Union[torch.Tensor, Dict], use_cuda: bool, *args) -> None: + assert isinstance(use_cuda, bool) + if use_cuda: + torch.cuda.synchronize(0) + return + + +def dict_tensor_send(mq: TORCHRPCMQ, state_dict: Dict, send_id: int, use_cuda: bool) -> None: + future_list = [] + for tensor_id, (key, value) in enumerate(state_dict.items()): + future_list.append(mq.publish("DICT_TENSOR_SEND", key, tensor_id, value, len(state_dict), send_id, use_cuda)) + + for future in future_list: + future.wait() + + +def perf_torch_rpc(use_cuda=True): + global LENGTH + global UNIT_SIZE_LIST + if use_cuda: + device = "cuda:0" + else: + device = "cpu" + + for i, unit_size in enumerate(UNIT_SIZE_LIST): + unit_tensor = torch.ones([unit_size * LENGTH]).to(device) + tensor_dict = {} + for j in range(LENGTH): + tensor_dict[str(j)] = torch.ones(unit_size).to(device) + + if use_cuda: + torch.cuda.synchronize(0) + + @time_perf_avg(1, REPEAT, cuda=use_cuda) + def one_shot_rpc(): + dict_tensor_send(mq, {'test': unit_tensor}, i, use_cuda) + + @time_perf_avg(1, REPEAT, cuda=use_cuda) + def one_shot_rpc_with_dict(): + dict_tensor_send(mq, tensor_dict, i, use_cuda) + + @time_perf_avg(1, REPEAT, cuda=use_cuda) + def split_chunk_rpc(): + re = mq.publish(RPCEvent.CUSTOM_FUNCRION_RPC, {'test': unit_tensor}, use_cuda, custom_method=send_dummy) + re.wait() + + @time_perf_avg(1, REPEAT, cuda=use_cuda) + def split_chunk_rpc_with_dict(): + re = mq.publish(RPCEvent.CUSTOM_FUNCRION_RPC, tensor_dict, use_cuda, custom_method=send_dummy) + re.wait() + + logging.debug("Size {:.2f} {}".format(*byte_beauty_print(unit_size * LENGTH * 4))) + + for idx in range(REPEAT): + one_shot_rpc(idx) + 
one_shot_rpc_with_dict(idx) + split_chunk_rpc(idx) + split_chunk_rpc_with_dict(idx) + + if use_cuda: + torch.cuda.empty_cache() + + +def perf_nccl(global_rank: int, use_cuda=True): + if use_cuda: + device = "cuda:0" + else: + device = "cpu" + ack_tensor = torch.ones(10).to(device) + + if global_rank == 0: + # Warm up recving + dist.recv(tensor=ack_tensor, src=1) + if use_cuda: + torch.cuda.synchronize(0) + + for i, unit_size in enumerate(UNIT_SIZE_LIST): + payload = torch.ones([unit_size * LENGTH]).to(device) + + @time_perf_avg(1, REPEAT, cuda=True) + def test_case_nccl(payload): + dist.send(tensor=payload, dst=1, tag=i) + + logging.debug("Size {:.2f} {}".format(*byte_beauty_print(unit_size * LENGTH * 4))) + + for idx in range(REPEAT): + test_case_nccl(idx, payload) + else: + # Warm up sending + dist.send(tensor=ack_tensor, dst=0) + if use_cuda: + torch.cuda.synchronize(0) + + for i, unit_size in enumerate(UNIT_SIZE_LIST): + recvbuffer = torch.ones([unit_size * LENGTH]).to(device) + for j in range(REPEAT): + dist.recv(tensor=recvbuffer, src=0, tag=i) + if use_cuda: + torch.cuda.synchronize(0) + + +def rpc_model_exchanger(rank: int, init_method: str, test_nccl: bool = False, use_cuda: bool = True): + global mq + global dict_tensor_send_by_key + global remote_mq_entrance + from ding.framework.parallel import Parallel + + logging.getLogger().setLevel(logging.DEBUG) + if test_nccl: + dist.init_process_group("nccl", rank=rank, world_size=2, init_method=init_method) + params = Parallel._torchrpc_args_parser( + n_parallel_workers=1, + attach_to=[1] if rank == 0 else [], + node_ids=[rank], + init_method=init_method, + use_cuda=use_cuda, + async_rpc=True, + async_backend_polling=False, + remote_parallel_entrance=remote_mq_entrance + )[0] + logging.debug(params) + mq = TORCHRPCMQ(**params) + mq.show_device_maps() + + # Because the dict_tensor_send_by_key() relies on global variables, we have to register it. + mq.subscribe("DICT_TENSOR_SEND", dict_tensor_send_by_key) + mq.listen() + + # In order to prevent deadlock caused by mixed use of "torch.cuda.synchronize" between + # nccl and torchrpc, we test the two backend separately. + if rank == 1: + # Receiver ready for testing nccl + if test_nccl: + perf_nccl(rank) + # Receiver join to wait sender to send shutdown signal. + mq.wait_for_shutdown() + elif rank == 0: + # Sender test torch rpc. + perf_torch_rpc(use_cuda=use_cuda) + # Sender test nccl. + if test_nccl: + perf_nccl(rank) + # Print test results. + print_timer_result_csv() + # Sender send finish signal. + mq.require_to_shutdown("Node_1") + # Sender clean resources. + mq.stop() + + +# Usage: +# CUDA_VISIBLE_DEVICES=0 python perf_torchrpc_nccl.py --rank=0 +# CUDA_VISIBLE_DEVICES=1 python perf_torchrpc_nccl.py --rank=1 +# +# Note: +# If you are in a container, please ensure that your /dev/shm is large enough. +# If there is a strange core or bug, please check if /dev/shm is full. 
+# If so, please try to clear it manually: +# /dev/shm/nccl* +# /dev/shm/cuda.shm.* +# /dev/shm/torch_* +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Test torch rpc') + parser.add_argument('--rank', type=int) + parser.add_argument('--init-method', type=str, default="tcp://127.0.0.1:12347") + parser.add_argument('--test_nccl', type=bool, default=False) + parser.add_argument('--use_cuda', type=bool, default=False) + args, _ = parser.parse_known_args() + + if args.use_cuda: + if "CUDA_VISIBLE_DEVICES" in os.environ: + logging.info("CUDA_VISIBLE_DEVICES: {}".format(os.environ['CUDA_VISIBLE_DEVICES'])) + else: + logging.info("Not set CUDA_VISIBLE_DEVICES!") + + logging.info( + "CUDA is enable:{}, nums of GPU: {}, current device: {}".format( + torch.cuda.is_available(), torch.cuda.device_count(), torch.cuda.current_device() + ) + ) + + rpc_model_exchanger(args.rank, args.init_method, args.test_nccl, args.use_cuda) diff --git a/ding/framework/message_queue/perfs/tests/test_perf_nng.py b/ding/framework/message_queue/perfs/tests/test_perf_nng.py new file mode 100644 index 0000000000..343d5d080d --- /dev/null +++ b/ding/framework/message_queue/perfs/tests/test_perf_nng.py @@ -0,0 +1,14 @@ +from ding.framework.message_queue.perfs.perf_nng import nng_perf_main +import multiprocessing as mp +import pytest + + +@pytest.mark.mqbenchmark +@pytest.mark.multiprocesstest +def test_nng(): + params = [ + ("12376", None, "127.0.0.1", "learner", "0"), ("12378", "tcp://127.0.0.1:12376", "127.0.0.1", "collector", "1") + ] + ctx = mp.get_context("spawn") + with ctx.Pool(processes=2) as pool: + pool.starmap(nng_perf_main, params) diff --git a/ding/framework/message_queue/perfs/tests/test_perf_shm.py b/ding/framework/message_queue/perfs/tests/test_perf_shm.py new file mode 100644 index 0000000000..2be1a2a047 --- /dev/null +++ b/ding/framework/message_queue/perfs/tests/test_perf_shm.py @@ -0,0 +1,20 @@ +from ding.framework.message_queue.perfs.perf_shm import shm_perf_main +import multiprocessing as mp +import pytest +import torch + + +@pytest.mark.mqbenchmark +@pytest.mark.cudatest +@pytest.mark.multiprocesstest +def test_shm_numpy_shm(): + if torch.cuda.is_available(): + shm_perf_main("shm") + + +@pytest.mark.mqbenchmark +@pytest.mark.cudatest +@pytest.mark.multiprocesstest +def test_shm_cuda_shared_tensor(): + if torch.cuda.is_available() and torch.cuda.device_count() >= 2: + shm_perf_main("cuda_ipc") diff --git a/ding/framework/message_queue/perfs/tests/test_perf_torchrpc_nccl.py b/ding/framework/message_queue/perfs/tests/test_perf_torchrpc_nccl.py new file mode 100644 index 0000000000..8af00e8a2a --- /dev/null +++ b/ding/framework/message_queue/perfs/tests/test_perf_torchrpc_nccl.py @@ -0,0 +1,18 @@ +from ding.framework.message_queue.perfs.perf_torchrpc_nccl import rpc_model_exchanger +from ding.compatibility import torch_ge_1121 +import multiprocessing as mp +import pytest +import torch +import platform + + +@pytest.mark.mqbenchmark +@pytest.mark.cudatest +@pytest.mark.multiprocesstest +def test_perf_torchrpc_nccl(): + if platform.system().lower() != 'windows' and torch.cuda.is_available(): + if torch_ge_1121() and torch.cuda.device_count() >= 2: + params = [(0, "tcp://127.0.0.1:12387", False, True), (1, "tcp://127.0.0.1:12387", False, True)] + ctx = mp.get_context("spawn") + with ctx.Pool(processes=2) as pool: + pool.starmap(rpc_model_exchanger, params) diff --git a/ding/framework/message_queue/redis.py b/ding/framework/message_queue/redis.py index 9cbf10e8a6..69860e3242 100644 --- 
a/ding/framework/message_queue/redis.py +++ b/ding/framework/message_queue/redis.py @@ -1,7 +1,7 @@ import uuid from ditk import logging from time import sleep -from typing import Tuple +from typing import Tuple, Optional import redis from ding.framework.message_queue.mq import MQ @@ -34,7 +34,7 @@ def publish(self, topic: str, data: bytes) -> None: data = self._id + b"::" + data self._client.publish(topic, data) - def subscribe(self, topic: str) -> None: + def subscribe(self, topic: str, fn: Optional[callable] = None, is_once: Optional[bool] = False) -> None: self._sub.subscribe(topic) def unsubscribe(self, topic: str) -> None: diff --git a/ding/framework/message_queue/tests/test_torch_rpc.py b/ding/framework/message_queue/tests/test_torch_rpc.py new file mode 100644 index 0000000000..1adf979021 --- /dev/null +++ b/ding/framework/message_queue/tests/test_torch_rpc.py @@ -0,0 +1,227 @@ +from ding.framework.message_queue.torch_rpc import DeviceMap, TORCHRPCMQ, DEFAULT_DEVICE_MAP_NUMS +from torch.distributed import rpc +from multiprocessing import Pool, get_context +from ding.compatibility import torch_ge_1121 +from ditk import logging + +import pytest +import torch +import platform +import time + +mq = None +recv_tensor_list = [None, None, None, None] + + +def remote_mq_entrance(topic, *args, **kwargs): + global mq + mq.rpc_event_router(topic, *args, **kwargs) + + +def torchrpc(rank): + global mq + global recv_tensor_list + mq = None + recv_tensor_list = [None, None, None, None] + logging.getLogger().setLevel(logging.DEBUG) + name_list = ["A", "B", "C", "D"] + + if rank == 0: + attach_to = name_list[1:] + else: + attach_to = None + + mq = TORCHRPCMQ( + rpc_name=name_list[rank], + global_rank=rank, + init_method="tcp://127.0.0.1:12398", + remote_parallel_entrance=remote_mq_entrance, + attach_to=attach_to, + async_rpc=False, + use_cuda=False + ) + + def fn1(tensor: torch.Tensor) -> None: + global recv_tensor_list + global mq + recv_tensor_list[0] = tensor + assert recv_tensor_list[0].sum().item() == 1000 + mq.publish("RANK_N_SEND", torch.ones(10), mq.global_rank) + + def fn2(tensor: torch.Tensor, rank) -> None: + global recv_tensor_list + recv_tensor_list[rank] = tensor + assert recv_tensor_list[rank].sum().item() == 10 + + mq.subscribe(topic="RANK_0_SEND", fn=fn1) + mq.subscribe(topic="RANK_N_SEND", fn=fn2) + mq.listen() + + if rank == 0: + mq.publish("RANK_0_SEND", torch.ones(1000)) + + mq._rendezvous_until_world_size(4) + all_worker_info = rpc._get_current_rpc_agent().get_worker_infos() + rpc.api._barrier([worker.name for worker in all_worker_info]) + + mq.unsubscribe("RANK_0_SEND") + assert "RANK_0_SEND" not in mq._rpc_events + + if rank == 0: + mq.publish("RANK_0_SEND", torch.ones(1000)) + + mq._rendezvous_until_world_size(4) + rpc.api._barrier(name_list) + mq.stop() + + +def torchrpc_cuda(rank): + global mq + global recv_tensor_list + mq = None + recv_tensor_list = [None, None, None, None] + name_list = ["A", "B"] + logging.getLogger().setLevel(logging.DEBUG) + + if rank == 0: + attach_to = name_list[1:] + else: + attach_to = None + + peer_rank = int(rank == 0) or 0 + peer_name = name_list[peer_rank] + device_map = DeviceMap(rank, [peer_name], [rank], [peer_rank]) + logging.debug(device_map) + + mq = TORCHRPCMQ( + rpc_name=name_list[rank], + global_rank=rank, + init_method="tcp://127.0.0.1:12390", + remote_parallel_entrance=remote_mq_entrance, + attach_to=attach_to, + device_maps=device_map, + async_rpc=False, + cuda_device=rank, + use_cuda=True + ) + + def fn1(tensor: torch.Tensor) -> 
None: + global recv_tensor_list + global mq + recv_tensor_list[0] = tensor + assert recv_tensor_list[0].sum().item() == 777 + assert recv_tensor_list[0].device == torch.device(1) + + mq.subscribe(topic="RANK_0_SEND", fn=fn1) + mq.listen() + + if rank == 0: + mq.publish("RANK_0_SEND", torch.ones(777).cuda(0)) + + mq._rendezvous_until_world_size(2) + all_worker_info = rpc._get_current_rpc_agent().get_worker_infos() + rpc.api._barrier([worker.name for worker in all_worker_info]) + mq.stop() + + +def torchrpc_args_parser(rank): + global mq + global recv_tensor_list + from ding.framework.parallel import Parallel + logging.getLogger().setLevel(logging.DEBUG) + + params = Parallel._torchrpc_args_parser( + n_parallel_workers=1, + attach_to=[], + node_ids=[0], + init_method="tcp://127.0.0.1:12399", + use_cuda=True, + local_cuda_devices=None, + cuda_device_map=None + )[0] + + logging.debug(params) + + # 1. If attach_to is empty, init_rpc will not block. + mq = TORCHRPCMQ(**params) + mq.listen() + assert mq._running + mq.stop() + assert not mq._running + logging.debug("[Pass] 1. If attach_to is empty, init_rpc will not block.") + + # 2. n_parallel_workers != len(node_ids) + try: + Parallel._torchrpc_args_parser(n_parallel_workers=999, attach_to=[], node_ids=[1, 2])[0] + except RuntimeError as e: + logging.debug("[Pass] 2. n_parallel_workers != len(node_ids).") + else: + assert False + + # 3. len(local_cuda_devices) != n_parallel_workers + try: + Parallel._torchrpc_args_parser(n_parallel_workers=8, node_ids=[1], local_cuda_devices=[1, 2, 3])[0] + except RuntimeError as e: + logging.debug("[Pass] 3. len(local_cuda_devices) != n_parallel_workers.") + else: + assert False + + # 4. n_parallel_workers > gpu_nums + # TODO(wgt): Support spwan mode to start torchrpc process using CPU/CUDA and CPU only. + try: + Parallel._torchrpc_args_parser(n_parallel_workers=999, node_ids=[1], use_cuda=True)[0] + except RuntimeError as e: + logging.debug("[Pass] 4. n_parallel_workers > gpu_nums.") + else: + assert False + + # 5. Set custom device map. + params = Parallel._torchrpc_args_parser( + n_parallel_workers=1, node_ids=[1], cuda_device_map=["0_0_0", "0_1_2", "1_1_4"] + )[0] + assert params['device_maps'].peer_name_list == ["Node_0", "Node_0", "Node_1"] + assert params['device_maps'].our_device_list == [0, 1, 1] + assert params['device_maps'].peer_device_list == [0, 2, 4] + # logging.debug(params['device_maps']) + logging.debug("[Pass] 5. Set custom device map.") + + # 6. Set n_parallel_workers > 1 + params = Parallel._torchrpc_args_parser(n_parallel_workers=8, node_ids=[1]) + assert len(params) == 8 + assert params[7]['node_id'] == 8 + assert params[0]['use_cuda'] is False + assert params[0]['device_maps'] is None + assert params[0]['cuda_device'] is None + + if torch.cuda.device_count() >= 2: + params = Parallel._torchrpc_args_parser(n_parallel_workers=2, node_ids=[1], use_cuda=True) + assert params[0]['use_cuda'] + assert len(params[0]['device_maps'].peer_name_list) == DEFAULT_DEVICE_MAP_NUMS - 1 + logging.debug("[Pass] 6. 
Set n_parallel_workers > 1.") + + +@pytest.mark.multiprocesstest +def test_torchrpc(): + ctx = get_context("spawn") + if platform.system().lower() != 'windows' and torch_ge_1121(): + with ctx.Pool(processes=4) as pool: + pool.map(torchrpc, range(4)) + + +@pytest.mark.cudatest +@pytest.mark.multiprocesstest +def test_torchrpc_cuda(): + if platform.system().lower() != 'windows': + if torch_ge_1121() and torch.cuda.is_available() and torch.cuda.device_count() >= 2: + ctx = get_context("spawn") + with ctx.Pool(processes=2) as pool: + pool.map(torchrpc_cuda, range(2)) + + +@pytest.mark.cudatest +@pytest.mark.multiprocesstest +def test_torchrpc_parser(): + if platform.system().lower() != 'windows' and torch_ge_1121() and torch.cuda.is_available(): + ctx = get_context("spawn") + with ctx.Pool(processes=1) as pool: + pool.map(torchrpc_args_parser, range(1)) diff --git a/ding/framework/message_queue/torch_rpc.py b/ding/framework/message_queue/torch_rpc.py new file mode 100644 index 0000000000..cee70c8bfb --- /dev/null +++ b/ding/framework/message_queue/torch_rpc.py @@ -0,0 +1,391 @@ +from ding.framework.message_queue.mq import MQ +from ding.utils import MQ_REGISTRY +from ditk import logging +from ding.utils import LockContext, LockContextType + +from typing import List, Optional, Tuple, Dict, Any, Union, Callable +from threading import Thread +from enum import Enum + +from torch.distributed import rpc + +import os +import time +import queue +import torch +import platform + +if platform.system().lower() != 'windows': + from torch.distributed.rpc import TensorPipeRpcBackendOptions + +DEFAULT_DEVICE_MAP_NUMS = 12 + + +# About RPCEvent: +# RPCEvent stores events that are not related to RL train logic. +# Private events use "int" to represent topic in order to reduce overhead, because the +# order and content of these events are hard-coded. The user-defined topic is uniquely +# identified by a string, because we cannot guarantee the order in which each process +# registers the same topic. +# +# There are four types of private events: +# 1. "CLINET_REGISTER_STUB": Responsible for the connect. +# 2. "CUSTOM_FUNCRION_RPC": Responsible for RPC which using provided RPC methods +# The remote function must be given with the positional parameter "custom_method". +# "custom_method" must be picklable, otherwise use subscribe() to register topic +# and corresponding method on the client side in advance. +# 3. "NOTIFY_SHUTDOWN": Responsible for the disconnect info from other process. +# 4. "REQUIRE_SHUTDOWN": Responsible for the disconnect request which was asked for. +class RPCEvent(int, Enum): + CLINET_REGISTER_STUB = 1 + CUSTOM_FUNCRION_RPC = 2 + NOTIFY_SHUTDOWN = 3 + REQUIRE_SHUTDOWN = 4 + + +class DeviceMap: + + def __init__( + self, + our_name: str, + peer_name_list: List[str] = None, + our_device_list: List[int] = None, + peer_device_list: List[int] = None + ) -> None: + """ + Overview: + Mapping management for gpu devices. + Arguments: + - peer_name_list (List[str], optional): remote processes unique rpc name. + - our_device_list (List[int], optional): local processes device rank. + - peer_device_list (List[int], optional): remote processes device rank. 
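+ Examples:
+ >>> # Illustrative: local process "Node_0" exposes its GPU 0 to GPU 1 of peer "Node_1".
+ >>> device_map = DeviceMap("Node_0", ["Node_1"], [0], [1])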
+ """ + + self.peer_name_list = peer_name_list or [] + self.our_device_list = our_device_list or [] + self.peer_device_list = peer_device_list or [] + + assert len(self.peer_name_list) == len(self.peer_name_list) + assert len(self.peer_device_list) == len(self.peer_device_list) + + self.our_name = str(our_name) + + def __str__(self): + info = "" + for i in range(len(self.peer_name_list)): + info += "{} : GPU-{} --> {} : GPU-{};{}".format( + self.our_name, str(self.our_device_list[i]), str(self.peer_name_list[i]), str(self.peer_device_list[i]), + "\n" if i != len(self.peer_name_list) - 1 else "" + ) + return info + + def set_device(self, option) -> None: + """ + Overview: + Initialize TensorPipeRpcBackendOptions according to the GPU mapping + set by the user. + Arguments: + - option (class TensorPipeRpcBackendOptions) + """ + for i in range(len(self.peer_name_list)): + option.set_device_map(self.peer_name_list[i], {self.our_device_list[i]: self.peer_device_list[i]}) + + +@MQ_REGISTRY.register("torchrpc") +class TORCHRPCMQ(MQ): + + def __init__( + self, + rpc_name: str, + init_method: str, + remote_parallel_entrance: Callable, + global_rank: int = 0, + attach_to: Optional[List[str]] = None, + device_maps: Optional[DeviceMap] = None, + async_rpc: Optional[bool] = True, + async_backend_polling: Optional[bool] = False, + use_cuda: Optional[bool] = False, + cuda_device: Optional[int] = None, + channels: Optional[List[str]] = None, + **kwargs + ) -> None: + """ + Overview: + Connect distributed processes with torch.distributed.rpc + Arguments: + - rpc_name (str): Globally unique name for rpc + - init_method (str): URL specifying how to initialize the process group. + - remote_parallel_entrance (Callable): Get the entry function of the remote Parallel() + struct. This function must ensure that the remote method call can find the corresponding + TORCHRPCMQ struct locally. + - attach_to (Optional[List[str]], optional): The ranks want to connect to, comma-separated ranks. + - global_rank (int, optional): Globally unique id. + - device_maps (DeviceMap, optional): Used for torch rpc init device_maps. + - async_rpc (Optional[bool]): Whether to use asynchronous rpc, the default is false. + - async_backend_polling (Optional[bool]): Whether to enable background threads to poll future objects + generated by asynchronous RPCs. + - use_cuda (Optional[bool]): Whether there will be data on the GPU side involved in the communication, + if true, torchrpc will set the device map. + - cuda_device (Optional[int]): An optional list of local devices, the default is all visible devices. + - channels (Optional[List[str]]): Channels contain the communication methods used by tensorpipe when + transmitting tensor, including the following possible values: "basic", "cma", "mpt_uv", "cuda_ipc", + "cuda_gdr", "cuda_xth". 
+ """ + self.name = rpc_name + self.global_rank = global_rank + + self._running = False + self.remote_parallel_entrance = remote_parallel_entrance + + self._peer_set = set(attach_to if attach_to else []) + self._peer_set_lock = LockContext(type_=LockContextType.THREAD_LOCK) + + if platform.system().lower() != 'windows': + self.rpc_backend_options = TensorPipeRpcBackendOptions( + num_worker_threads=16, rpc_timeout=30, init_method=init_method, _channels=channels + ) + else: + raise WindowsError("TensorPipe does not support Windows yet!") + + if use_cuda: + assert torch.cuda.is_available() + assert device_maps + assert cuda_device is not None + + self._device_maps = device_maps + self._device_maps.set_device(self.rpc_backend_options) + self.rpc_backend_options.set_devices([cuda_device]) + else: + self._device_maps = None + + self._rpc_events = { + RPCEvent.CLINET_REGISTER_STUB: self.accept_rpc_connect, + RPCEvent.CUSTOM_FUNCRION_RPC: self.call_custom_rpc_method, + RPCEvent.NOTIFY_SHUTDOWN: self.notify_shutdown, + RPCEvent.REQUIRE_SHUTDOWN: self.stop + } + + self._async = async_rpc + self._async_backend_polling = async_rpc and async_backend_polling + if self._async_backend_polling: + self.async_future_queue = queue.Queue() + # Using threads to poll performance suffers due to the presence of Python GIL locks. + self.polling_thread = Thread(target=self._backend_polling, name="backend_polling", daemon=True) + + logging.debug( + "Torchrpc info: process name:\"{}\", node_id:[{}], attach_to[{}], init_method:{}.".format( + self.name, self.global_rank, self._peer_set, init_method + ) + ) + + def show_device_maps(self): + if self._device_maps: + logging.info("{}".format(self._device_maps)) + else: + logging.info("Not set device map!") + + def subscribe(self, topic: Union[int, str], fn: Optional[Callable] = None, is_once: Optional[bool] = False) -> None: + if fn is None: + raise RuntimeError("The Torchrpc subscription topic must be provided with a callback function.") + if topic not in self._rpc_events: + + def once_callback(*args, **kwargs): + fn(*args, **kwargs) + self.unsubscribe(topic) + + self._rpc_events[topic] = fn if not is_once else once_callback + + def unsubscribe(self, topic: Union[int, str]) -> None: + if topic in self._rpc_events: + self._rpc_events.pop(topic) + + def rpc_event_router(self, topic: Union[int, str], *args, **kwargs) -> Any: + """ + Overview: + Entry function called after all remote methods reach the target process. + Arguments: + - topic (Union[int, str]): Recevied topic. + """ + if topic not in self._rpc_events: + logging.warning("{} Torchrpc topic \"{}\" is not registered.".format(self.name, topic)) + return + + return (self._rpc_events[topic])(*args, **kwargs) + + def listen(self) -> None: + # If device_map is not specified, init_rpc will block until all processes + # smaller than the current rank call init_rpc. If device_map is specified, + # then init_rpc blocks until all processes present in device_map call init_rpc. 
+ rpc.init_rpc(name=self.name, rank=self.global_rank, rpc_backend_options=self.rpc_backend_options) + + # Wait for all processes rendezvous before starting subsequent steps + for i, peer in enumerate(self._peer_set): + while True: + try: + self._do_rpc(peer, RPCEvent.CLINET_REGISTER_STUB, self.name, self.global_rank) + except Exception as e: + logging.debug( + "\"{}\" try to rendezvous with \"{}\" error, because \"{}\"".format(self.name, peer, e) + ) + time.sleep(0.5) + continue + else: + logging.debug("\"{}\" irendezvous with \"{}\" success!".format(self.name, peer)) + break + + if self._async_backend_polling: + self.polling_thread.start() + self._running = True + + logging.debug("\"{}\" Torchrpc backend init success.".format(self.name)) + + def publish(self, topic: Union[int, str], *args, **kwargs) -> Any: + if self._running: + timeout_list = [] + + if len(self._peer_set) == 0: + logging.warning("No peer available to communicate with") + return + + with self._peer_set_lock: + for peer in self._peer_set: + if not self._running: + break + try: + re = self._do_rpc(peer, topic, *args, **kwargs) + except RuntimeError as e: + logging.error("Publish topic \"{}\" to peer \"{}\" has error: \"{}\"!".format(topic, peer, e)) + timeout_list.append(peer) + + for timeout_peer in timeout_list: + self._peer_set.remove(timeout_peer) + + return re + + def accept_rpc_connect(self, peer_name: str, peer_rank: int) -> None: + """ + Overview: + Receive the link signal sent by the peer. + Arguments: + - peer_name (str) + - peer_rank (int): + """ + with self._peer_set_lock: + if peer_name not in self._peer_set: + self._peer_set.add(peer_name) + + return + + def call_custom_rpc_method(self, *args, **kwargs) -> Any: + """ + Overview: + If the upper-level module wants to pass in a custom rpc method, + it will be called remotly by this function. + """ + fn = kwargs.pop('custom_method') + return fn(*args, **kwargs) + + def notify_shutdown(self, peer_name: str, *args, **kwargs) -> None: + """ + Overview: + Receive the exit signal sent by the peer. + Arguments: + - peer_name (str) + """ + with self._peer_set_lock: + if peer_name in self._peer_set: + logging.info("\"{}\" recv shutdown info from \"{}\".".format(self.name, peer_name)) + self._peer_set.remove(peer_name) + + def recv(self): + raise NotImplementedError + + def stop(self) -> None: + if self._running: + with self._peer_set_lock: + for peer in self._peer_set: + try: + self._do_rpc(peer, RPCEvent.NOTIFY_SHUTDOWN, self.name) + except RuntimeError as e: + continue + + if self._async_backend_polling: + while self.async_future_queue.qsize() > 0: + time.sleep(0.05) + continue + + self.polling_thread.join(timeout=1) + self.polling_thread = None + + self._running = False + + # Set graceful=False, we do not wait for other RPC processes to reach this method. + rpc.shutdown(graceful=False) + + logging.info("\"{}\" Torchrpc backend is stopped.".format(self.name)) + + def require_to_shutdown(self, peer_name: str): + """ + Overview: + Request the remote torch rpc message queue to be stopped. 
+ Arguments: + - peer_name (str): Remote torch rpc message's name + """ + try: + re = self._do_rpc(peer_name, RPCEvent.REQUIRE_SHUTDOWN) + if self._async and not self._async_backend_polling: + re.wait() + except RuntimeError as e: + logging.warning("Torchrpc polling_thread error: \"{}\".".format(e)) + + def _rendezvous_until_world_size(self, world_size) -> None: + while True: + all_worker_info = rpc._get_current_rpc_agent().get_worker_infos() + if len(all_worker_info) != world_size: + time.sleep(0.5) + else: + break + + def _do_rpc(self, peer: str, topic: Union[int, str] = Optional[None], *arg, **kwargs) -> Union[None, Any]: + """ + Overview: + Where the actual RPC communication takes place + Arguments: + - peer (str): [The rpc name of the peer] + - topic (int): [The topic passed by upstream] + """ + arg = [topic] + list(arg) + + if self._async: + future = rpc.rpc_async(peer, self.remote_parallel_entrance, args=arg, kwargs=kwargs) + if not self._async_backend_polling: + return future + else: + self.async_future_queue.put(future) + return None + else: + return rpc.rpc_sync(peer, self.remote_parallel_entrance, args=arg, kwargs=kwargs) + + def _backend_polling(self) -> None: + while True: + if not self._running: + break + + future = self.async_future_queue.get() + try: + if not future.done(): + time.sleep(0.05) + future.wait() + except RuntimeError as e: + logging.warning("Torchrpc polling thread catch RuntimeError: \"{}\".".format(e)) + + def wait_for_shutdown(self): + """ + Overview: + The thread calling this method will block until mq receives a request for shutdown. + """ + while True: + if not self._running: + break + else: + time.sleep(0.5) diff --git a/ding/framework/middleware/distributer.py b/ding/framework/middleware/distributer.py index c68a4b808f..8f53068138 100644 --- a/ding/framework/middleware/distributer.py +++ b/ding/framework/middleware/distributer.py @@ -2,8 +2,10 @@ from dataclasses import fields from typing import TYPE_CHECKING, List, Dict, Any, Optional, Union from ditk import logging -from ding.framework import task +from ding.framework import task, MQType from ding.data import StorageLoader, Storage, ModelLoader +from ding.utils import LockContext, LockContextType + if TYPE_CHECKING: from ding.framework.context import Context from torch.nn import Module @@ -11,7 +13,11 @@ class ContextExchanger: - def __init__(self, skip_n_iter: int = 1, storage_loader: Optional[StorageLoader] = None) -> None: + def __init__( + self, + skip_n_iter: int = 1, + storage_loader: Optional[StorageLoader] = None, + ) -> None: """ Overview: Exchange context between processes, @@ -33,9 +39,16 @@ def __init__(self, skip_n_iter: int = 1, storage_loader: Optional[StorageLoader] self._event_name = "context_exchanger_{role}" self._skip_n_iter = skip_n_iter self._storage_loader = storage_loader + + # Both nng and torchrpc use background threads to trigger the receiver's recv action, + # there is a race condition between sender and sender, and between senders and receiver. 
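+ # _put_lock serializes access to self._state and self._recv_ready: they are written by the
+ # message-queue callback thread in put() and read and cleared by the pipeline thread in merge().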
+ self._put_lock = LockContext(LockContextType.THREAD_LOCK) + self._recv_ready = False + self._bypass_eventloop = task.router.mq_type == MQType.RPC + for role in task.role: # Only subscribe to other roles if not task.has_role(role): - task.on(self._event_name.format(role=role), self.put) + task.on(self._event_name.format(role=role), self.put, bypass_eventloop=self._bypass_eventloop) if storage_loader: task.once("finish", lambda _: storage_loader.shutdown()) @@ -62,7 +75,12 @@ def __call__(self, ctx: "Context"): if self._storage_loader and task.has_role(task.role.COLLECTOR): payload = self._storage_loader.save(payload) for role in task.roles: - task.emit(self._event_name.format(role=role), payload, only_remote=True) + task.emit( + self._event_name.format(role=role), + payload, + only_remote=True, + bypass_eventloop=self._bypass_eventloop + ) def __del__(self): if self._storage_loader: @@ -76,12 +94,14 @@ def put(self, payload: Union[Dict, Storage]): """ def callback(payload: Dict): - for key, item in payload.items(): - fn_name = "_put_{}".format(key) - if hasattr(self, fn_name): - getattr(self, fn_name)(item) - else: - logging.warning("Receive unexpected key ({}) in context exchanger".format(key)) + with self._put_lock: + for key, item in payload.items(): + fn_name = "_put_{}".format(key) + if hasattr(self, fn_name): + getattr(self, fn_name)(item) + else: + logging.warning("Receive unexpected key ({}) in context exchanger".format(key)) + self._recv_ready = True if isinstance(payload, Storage): assert self._storage_loader is not None, "Storage loader is not defined when data is a storage object." @@ -106,26 +126,29 @@ def fetch(self, ctx: "Context") -> Dict[str, Any]: return payload def merge(self, ctx: "Context"): + if task.has_role(task.role.LEARNER): # Learner should always wait for trajs. # TODO: Automaticlly wait based on properties, not roles. - while len(self._state) == 0: + while self._recv_ready is False: sleep(0.01) elif ctx.total_step >= self._skip_n_iter: start = time() - while len(self._state) == 0: + while self._recv_ready is False: if time() - start > 60: logging.warning("Timeout when waiting for new context! 
Node id: {}".format(task.router.node_id)) break sleep(0.01) - for k, v in self._state.items(): - if not task.has_role(task.role.COLLECTOR) and k.startswith('increment_'): - pure_k = k.split('increment_')[-1] - setattr(ctx, pure_k, getattr(ctx, pure_k) + v) - else: - setattr(ctx, k, v) - self._state = {} + with self._put_lock: + for k, v in self._state.items(): + if not task.has_role(task.role.COLLECTOR) and k.startswith('increment_'): + pure_k = k.split('increment_')[-1] + setattr(ctx, pure_k, getattr(ctx, pure_k) + v) + else: + setattr(ctx, k, v) + self._state = {} + self._recv_ready = False # Handle each attibute of context def _put_trajectories(self, traj: List[Any]): @@ -150,14 +173,14 @@ def _fetch_episodes(self, episodes: List[Any]): if task.has_role(task.role.COLLECTOR): return episodes - def _put_trajectory_end_idx(self, trajectory_end_idx: List[str]): + def _put_trajectory_end_idx(self, trajectory_end_idx: List[int]): if not task.has_role(task.role.LEARNER): return if "trajectory_end_idx" not in self._state: self._state["trajectory_end_idx"] = [] self._state["trajectory_end_idx"].extend(trajectory_end_idx) - def _fetch_trajectory_end_idx(self, trajectory_end_idx: List[str]): + def _fetch_trajectory_end_idx(self, trajectory_end_idx: List[int]): if task.has_role(task.role.COLLECTOR): return trajectory_end_idx @@ -179,12 +202,6 @@ def _put_env_episode(self, increment_env_episode: int): self._state['increment_env_episode'] = 0 self._state["increment_env_episode"] += increment_env_episode - def _fetch_env_episode(self, env_episode: int): - if task.has_role(task.role.COLLECTOR): - increment_env_episode = env_episode - self._local_state['env_episode'] - self._local_state['env_episode'] = env_episode - return increment_env_episode - def _put_train_iter(self, train_iter: int): if not task.has_role(task.role.LEARNER): self._state["train_iter"] = train_iter @@ -211,8 +228,9 @@ def __init__(self, model: "Module", model_loader: Optional[ModelLoader] = None) self._event_name = "model_exchanger" self._state_dict_cache: Optional[Union[object, Storage]] = None self._is_learner = task.has_role(task.role.LEARNER) + self._bypass_eventloop = task.router.mq_type == MQType.RPC if not self._is_learner: - task.on(self._event_name, self._cache_state_dict) + task.on(self._event_name, self._cache_state_dict, bypass_eventloop=self._bypass_eventloop) if model_loader: task.once("finish", lambda _: model_loader.shutdown()) @@ -278,11 +296,13 @@ def _send_model(self): if self._model_loader: self._model_loader.save(self._send_callback) else: - task.emit(self._event_name, self._model.state_dict(), only_remote=True) + task.emit( + self._event_name, self._model.state_dict(), only_remote=True, bypass_eventloop=self._bypass_eventloop + ) def _send_callback(self, storage: Storage): if task.running: - task.emit(self._event_name, storage, only_remote=True) + task.emit(self._event_name, storage, only_remote=True, bypass_eventloop=self._bypass_eventloop) def __del__(self): if self._model_loader: diff --git a/ding/framework/middleware/functional/collector.py b/ding/framework/middleware/functional/collector.py index 20820d7d00..16930db826 100644 --- a/ding/framework/middleware/functional/collector.py +++ b/ding/framework/middleware/functional/collector.py @@ -5,6 +5,7 @@ from ding.envs import BaseEnvManager from ding.policy import Policy from ding.torch_utils import to_ndarray, get_shape0 +from ding.torch_utils import to_device if TYPE_CHECKING: from ding.framework import OnlineRLContext @@ -98,6 +99,10 @@ def 
rolloutor(policy: Policy, env: BaseEnvManager, transitions: TransitionList) env_episode_id = [_ for _ in range(env.env_num)] current_id = env.env_num + use_cuda_shared_memory = False + + if hasattr(cfg, "env") and hasattr(cfg.env, "manager"): + use_cuda_shared_memory = cfg.env.manager.cuda_shared_memory def _rollout(ctx: "OnlineRLContext"): """ @@ -113,16 +118,30 @@ def _rollout(ctx: "OnlineRLContext"): trajectory stops. """ - nonlocal current_id + nonlocal current_id, use_cuda_shared_memory timesteps = env.step(ctx.action) ctx.env_step += len(timesteps) - timesteps = [t.tensor() for t in timesteps] + + if not use_cuda_shared_memory: + timesteps = [t.tensor() for t in timesteps] + # TODO abnormal env step for i, timestep in enumerate(timesteps): transition = policy.process_transition(ctx.obs[i], ctx.inference_output[i], timestep) transition = ttorch.as_tensor(transition) # TBD transition.collect_train_iter = ttorch.as_tensor([ctx.train_iter]) transition.env_data_id = ttorch.as_tensor([env_episode_id[timestep.env_id]]) + + # torchrpc currently uses "cuda:0" as the transmission device by default, + # so all data on the cpu side is copied to "cuda:0" here. In fact this + # copy is unnecessary, because torchrpc can support both cpu side and gpu + # side data to communicate using RDMA, but mixing the two transfer types + # will cause a bug, see issue: + # Because we have copied the large payload "obs" and "next_obs" from the + # collector's subprocess to "cuda:0" in advance, the copy operation here + # will not have too much overhead. + if use_cuda_shared_memory: + transition = to_device(transition, "cuda:0") transitions.append(timestep.env_id, transition) if timestep.done: policy.reset([timestep.env_id]) diff --git a/ding/framework/parallel.py b/ding/framework/parallel.py index 38e343e495..e8369a4476 100644 --- a/ding/framework/parallel.py +++ b/ding/framework/parallel.py @@ -8,20 +8,31 @@ from ditk import logging import tempfile import socket +import enum from os import path -from typing import Callable, Dict, List, Optional, Tuple, Union, Set +from typing import Callable, Dict, List, Optional, Tuple, Union, Set, Any from threading import Thread from ding.framework.event_loop import EventLoop from ding.utils.design_helper import SingletonMetaclass from ding.framework.message_queue import * from ding.utils.registry_factory import MQ_REGISTRY +from easydict import EasyDict +from ding.framework.message_queue.torch_rpc import DeviceMap, DEFAULT_DEVICE_MAP_NUMS # Avoid ipc address conflict, random should always use random seed random = random.Random() +class MQType(int, enum.Enum): + NNG = 0 + REDIS = 1 + RPC = 2 + + class Parallel(metaclass=SingletonMetaclass): + _MQtype_dict = {"nng": MQType.NNG, "redis": MQType.REDIS, "torchrpc": MQType.RPC} + def __init__(self) -> None: # Init will only be called once in a process self._listener = None @@ -29,7 +40,6 @@ def __init__(self) -> None: self.node_id = None self.local_id = None self.labels = set() - self._event_loop = EventLoop("parallel_{}".format(id(self))) self._retries = 0 # Retries in auto recovery def _run( @@ -52,9 +62,18 @@ def _run( self.auto_recover = auto_recover self.max_retries = max_retries self._mq = MQ_REGISTRY.get(mq_type)(**kwargs) + self.mq_type = self._MQtype_dict[mq_type] + + if self.mq_type != MQType.RPC: + self._event_loop = EventLoop("parallel_{}".format(id(self))) + time.sleep(self.local_id * self.startup_interval) - self._listener = Thread(target=self.listen, name="mq_listener", daemon=True) - self._listener.start() + if 
self.mq_type == MQType.RPC: + self._mq.listen() + self.rpc_name = self._mq.name + else: + self._listener = Thread(target=self.listen, name="mq_listener", daemon=True) + self._listener.start() @classmethod def runner( @@ -72,7 +91,11 @@ def runner( max_retries: int = float("inf"), redis_host: Optional[str] = None, redis_port: Optional[int] = None, - startup_interval: int = 1 + init_method: Optional[str] = "env://", + startup_interval: int = 1, + use_cuda: Optional[bool] = False, + local_cuda_devices: Optional[List[str]] = None, + cuda_device_map: Optional[List[str]] = None ) -> Callable: """ Overview: @@ -100,7 +123,11 @@ def runner( """ all_args = locals() del all_args["cls"] - args_parsers = {"nng": cls._nng_args_parser, "redis": cls._redis_args_parser} + args_parsers = { + MQType.NNG: cls._nng_args_parser, + MQType.REDIS: cls._redis_args_parser, + MQType.RPC: cls._torchrpc_args_parser + } assert n_parallel_workers > 0, "Parallel worker number should bigger than 0" @@ -111,7 +138,7 @@ def _runner(main_process: Callable, *args, **kwargs) -> None: Arguments: - main_process (:obj:`Callable`): The main function, your program start from here. """ - runner_params = args_parsers[mq_type](**all_args) + runner_params = args_parsers[cls._MQtype_dict[mq_type]](**all_args) params_group = [] for i, runner_kwargs in enumerate(runner_params): runner_kwargs["local_id"] = i @@ -297,8 +324,13 @@ def on(self, event: str, fn: Callable) -> None: - fn (:obj:`Callable`): Function body. """ if self.is_active: - self._mq.subscribe(event) - self._event_loop.on(event, fn) + if self.mq_type == MQType.RPC: + self._mq.subscribe(event, fn) + else: + self._mq.subscribe(event) + + if hasattr(self, "_event_loop"): + self._event_loop.on(event, fn) def once(self, event: str, fn: Callable) -> None: """ @@ -310,8 +342,13 @@ def once(self, event: str, fn: Callable) -> None: - fn (:obj:`Callable`): Function body. """ if self.is_active: - self._mq.subscribe(event) - self._event_loop.once(event, fn) + if self.mq_type == MQType.RPC: + self._mq.subscribe(event, fn, True) + else: + self._mq.subscribe(event) + + if hasattr(self, "_event_loop"): + self._event_loop.once(event, fn) def off(self, event: str) -> None: """ @@ -322,7 +359,9 @@ def off(self, event: str) -> None: """ if self.is_active: self._mq.unsubscribe(event) - self._event_loop.off(event) + + if hasattr(self, "_event_loop"): + self._event_loop.off(event) def emit(self, event: str, *args, **kwargs) -> None: """ @@ -332,13 +371,16 @@ def emit(self, event: str, *args, **kwargs) -> None: - event (:obj:`str`): Event name. """ if self.is_active: - payload = {"a": args, "k": kwargs} - try: - data = pickle.dumps(payload, protocol=pickle.HIGHEST_PROTOCOL) - except AttributeError as e: - logging.error("Arguments are not pickable! Event: {}, Args: {}".format(event, args)) - raise e - self._mq.publish(event, data) + if self.mq_type == MQType.RPC: + self._mq.publish(event, *args, **kwargs) + else: + payload = {"a": args, "k": kwargs} + try: + data = pickle.dumps(payload, protocol=pickle.HIGHEST_PROTOCOL) + except AttributeError as e: + logging.error("Arguments are not pickable! Event: {}, Args: {}".format(event, args)) + raise e + self._mq.publish(event, data) def _handle_message(self, topic: str, msg: bytes) -> None: """ @@ -349,6 +391,8 @@ def _handle_message(self, topic: str, msg: bytes) -> None: - msg (:obj:`bytes`): Recevied message. 
""" event = topic + assert hasattr(self, "_event_loop") and self._event_loop + if not self._event_loop.listened(event): logging.debug("Event {} was not listened in parallel {}".format(event, self.node_id)) return @@ -382,10 +426,182 @@ def stop(self): logging.info("Stopping parallel worker on node: {}".format(self.node_id)) self.is_active = False time.sleep(0.03) - if self._mq: + if hasattr(self, "_mq") and self._mq: self._mq.stop() self._mq = None if self._listener: self._listener.join(timeout=1) self._listener = None - self._event_loop.stop() + if hasattr(self, "_event_loop"): + self._event_loop.stop() + + @classmethod + def make_device_maps(cls, + self_id: int, + local_cuda_device: int, + cuda_device_map: Optional[List[str]] = None) -> List[DeviceMap]: + dmap = DeviceMap(f"Node_{self_id}") + if cuda_device_map: + # If the user gave a custom device map, use it. + for item in cuda_device_map: + remote_node_id, local_device_rank, remote_device_rank = item.split("_") + dmap.peer_name_list.append(f"Node_{remote_node_id}") + dmap.our_device_list.append(int(local_device_rank)) + dmap.peer_device_list.append(int(remote_device_rank)) + else: + assert self_id < DEFAULT_DEVICE_MAP_NUMS + # If the user does not provide deivce_map and specifies the use of GPU, we default + # each process to use GPU:0 for communication. This is a convenient approach in a + # container environment. + for i in range(DEFAULT_DEVICE_MAP_NUMS): + if i == self_id: + continue + dmap.peer_name_list.append(f"Node_{i}") + dmap.our_device_list.append(local_cuda_device) + dmap.peer_device_list.append(0) + + return dmap + + @classmethod + def _torchrpc_args_parser( + cls, + n_parallel_workers: int, + attach_to: Optional[List[str]] = None, + node_ids: Optional[Union[List[int], int]] = None, + init_method: Optional[str] = "env://", + use_cuda: Optional[bool] = False, + local_cuda_devices: Optional[List[str]] = None, + cuda_device_map: Optional[List[str]] = None, + remote_parallel_entrance: Optional[Callable] = None, + async_rpc: Optional[bool] = True, + async_backend_polling: Optional[bool] = False, + channels: Optional[List[str]] = None, + **kwargs + ) -> List[Dict[str, dict]]: + import torch + assert init_method + + attach_to = attach_to or [] + node_divice_dict = dict() + + if local_cuda_devices or cuda_device_map: + use_cuda = True + if local_cuda_devices and not cuda_device_map: + logging.warning( + '''If you set local_cuda_devices but not cuda_device_map, torchrpc will use the default + device mapping to map all local GPU devices to the peer GPU-0.''' + ) + + # From the unique identification of each process when using torchrpc to communicate. + local_process_ids = cls.padding_param(node_ids, n_parallel_workers, 0) + attach_to = [f"Node_{id}" for id in attach_to] + nodes = ["Node_{}".format(id) for id in local_process_ids] + + try: + # torchrpc uses "node_id" as global rank, perform necessary checks here. + assert local_process_ids + assert len(local_process_ids) == n_parallel_workers + assert len(set(local_process_ids)) == n_parallel_workers + except AssertionError as e: + raise RuntimeError( + '''Arg "node_ids" must be specified. Please set the number of "node_ids" to be the same as + "n_parallel_workers" (Hint: "node_id" is the unique identifier between processes)''' + ) + + if use_cuda: + assert torch.cuda.is_available() + if local_cuda_devices: + if len(local_cuda_devices) != n_parallel_workers: + raise RuntimeError( + "The length of the \"local_cuda_devices\":[\"{}\"] is != \"n_parallel_workers\":[\"{}\"]". 
+ format(len(local_cuda_devices), n_parallel_workers) + ) + local_cuda_devices = [int(i) for i in local_cuda_devices] + else: + gpu_nums = torch.cuda.device_count() + if n_parallel_workers > gpu_nums: + raise RuntimeError( + "The number of available GPUS [\"{}\"] is less than n_parallel_workers[\"{}\"]".format( + gpu_nums, n_parallel_workers + ) + ) + local_cuda_devices = cls.padding_param(0, n_parallel_workers, 0) + + dmap_lists = [ + cls.make_device_maps(node_id, local_cuda_devices[i], cuda_device_map) + for i, node_id in enumerate(local_process_ids) + ] + else: + local_cuda_devices = [None for i in range(n_parallel_workers)] + dmap_lists = [None for i in range(n_parallel_workers)] + + if channels: + list_channels = [channels for i in range(n_parallel_workers)] + else: + list_channels = [None for i in range(n_parallel_workers)] + + global local_parallel_entrance + entrance_fn = remote_parallel_entrance if remote_parallel_entrance else local_parallel_entrance + runner_params = [] + for i in range(n_parallel_workers): + runner_kwargs = { + **kwargs, "node_id": local_process_ids[i], + "n_parallel_workers": n_parallel_workers, + "rpc_name": nodes[i], + "global_rank": local_process_ids[i], + "init_method": init_method, + "remote_parallel_entrance": entrance_fn, + "attach_to": attach_to, + "device_maps": dmap_lists[i], + "cuda_device": local_cuda_devices[i], + "use_cuda": use_cuda, + "async_rpc": async_rpc, + "async_backend_polling": async_backend_polling, + "channels": list_channels[i] + } + runner_params.append(runner_kwargs) + + return runner_params + + def get_mq(self): + return self._mq + + def judge_use_cuda_shm(self, cfg: EasyDict) -> None: + """ + Overview: + Only when torchrpc is used and env uses shared memory, cuda tensor + is used as the communication method between env subprocesses and + collector process. + Arguments: + - cfg (:obj:`EasyDict`): Input config dict which is to be used in the following pipeline. + """ + if not hasattr(cfg, "env") or not hasattr(cfg.env, "manager"): + return + + if cfg.env.manager.shared_memory: + if self.mq_type == MQType.RPC and "collector" in self.labels: + cfg.env.manager.cuda_shared_memory = True + return + cfg.env.manager.cuda_shared_memory = False + return + + +def local_parallel_entrance(topic: Union[int, str], *args, **kwargs) -> Any: + """ + Overview: + We must provide a method for all RPC methods to obtain the data structure + instantiated in the remote process. Because we don't want to and can't pickle + data structures such as Task() or Parallel(). + + Unlike nng, torchrpc needs to consider thread safety. Class 'Parallel' is a singleton + class. At this moment, Parallel() must have been instantiated, because + 'accept_rpc_connect'will only be executed after local-side init_rpc has completed, + + This function must be picklable, so should not be a local function. + + This function will be called concurrently by multiple threads, and the provider of + the RPC method needs to ensure that its own RPC method is thread-safe. + Arguments: + - topic (Union[int, str]): Recevied topic. 
+ """ + return Parallel().get_mq().rpc_event_router(topic, *args, **kwargs) diff --git a/ding/framework/task.py b/ding/framework/task.py index ae6e0e256d..131349bb40 100644 --- a/ding/framework/task.py +++ b/ding/framework/task.py @@ -13,7 +13,7 @@ import inspect from ding.framework.context import Context -from ding.framework.parallel import Parallel +from ding.framework.parallel import Parallel, MQType from ding.framework.event_loop import EventLoop from functools import wraps @@ -201,6 +201,7 @@ def run(self, max_step: int = int(1e12)) -> None: assert self._running, "Please make sure the task is running before calling the this method, see the task.start" if len(self._middleware) == 0: return + start_time = 0 for i in range(max_step): for fn in self._middleware: self.forward(fn) @@ -215,6 +216,11 @@ def run(self, max_step: int = int(1e12)) -> None: break self.renew() + if i == 0: + # Skip the first round of timing + start_time = time.time() + return start_time + def wrap(self, fn: Callable, lock: Union[bool, Lock] = False) -> Callable: """ Overview: @@ -424,7 +430,15 @@ def async_executor(self, fn: Callable, *args, **kwargs) -> None: t = self._async_loop.run_in_executor(self._thread_pool, fn, *args, **kwargs) self._async_stack.append(t) - def emit(self, event: str, *args, only_remote: bool = False, only_local: bool = False, **kwargs) -> None: + def emit( + self, + event: str, + *args, + only_remote: bool = False, + only_local: bool = False, + bypass_eventloop: bool = False, + **kwargs + ) -> None: """ Overview: Emit an event, call listeners. @@ -432,43 +446,66 @@ def emit(self, event: str, *args, only_remote: bool = False, only_local: bool = - event (:obj:`str`): Event name. - only_remote (:obj:`bool`): Only broadcast the event to the connected nodes, default is False. - only_local (:obj:`bool`): Only emit local event, default is False. + - bypass_eventloop (:obj:`bool`): Whether to select to bypass eventloop of Task() and Parallel(), + this parameter can only be True when torchrpc is used as the communication backend. If use torchrpc, + the invoked of the callback is triggered by the torchrpc's backend thread. - args (:obj:`any`): Rest arguments for listeners. """ # Check if need to broadcast event to connected nodes, default is True assert self._running, "Please make sure the task is running before calling the this method, see the task.start" - if only_local: - self._event_loop.emit(event, *args, **kwargs) - elif only_remote: + if bypass_eventloop: if self.router.is_active: - self.async_executor(self.router.emit, self._wrap_event_name(event), event, *args, **kwargs) + self.router.emit(self._wrap_event_name(event), *args, **kwargs) else: - if self.router.is_active: - self.async_executor(self.router.emit, self._wrap_event_name(event), event, *args, **kwargs) + if only_local: + self._event_loop.emit(event, *args, **kwargs) + elif only_remote: + if self.router.is_active: + self.async_executor(self.router.emit, self._wrap_event_name(event), event, *args, **kwargs) + else: + if self.router.is_active: + self.async_executor(self.router.emit, self._wrap_event_name(event), event, *args, **kwargs) self._event_loop.emit(event, *args, **kwargs) - def on(self, event: str, fn: Callable) -> None: + def on(self, event: str, fn: Callable, bypass_eventloop: Optional[bool] = False) -> None: """ Overview: Subscribe to an event, execute this function every time the event is emitted. Arguments: - event (:obj:`str`): Event name. - fn (:obj:`Callable`): The function. 
+ - bypass_eventloop (:obj:`bool`): Same as the bypass_eventloop arg in Task.emit. """ - self._event_loop.on(event, fn) - if self.router.is_active: - self.router.on(self._wrap_event_name(event), self._event_loop.emit) + if bypass_eventloop: + if self.router.mq_type == MQType.RPC: + if self.router.is_active: + self.router.on(self._wrap_event_name(event), fn) + else: + raise RuntimeError("Only message queue implemented by torchrpc allows bypass eventloop") + else: + self._event_loop.on(event, fn) + if self.router.is_active: + self.router.on(self._wrap_event_name(event), self._event_loop.emit) - def once(self, event: str, fn: Callable) -> None: + def once(self, event: str, fn: Callable, bypass_eventloop: Optional[bool] = False) -> None: """ Overview: Subscribe to an event, execute this function only once when the event is emitted. Arguments: - event (:obj:`str`): Event name. - fn (:obj:`Callable`): The function. + - bypass_eventloop (:obj:`bool`): Same as the bypass_eventloop arg in Task.emit. """ - self._event_loop.once(event, fn) - if self.router.is_active: - self.router.on(self._wrap_event_name(event), self._event_loop.emit) + if bypass_eventloop: + if self.router.mq_type == MQType.RPC: + if self.router.is_active: + self.router.once(self._wrap_event_name(event), fn) + else: + raise RuntimeError("Only message queue implemented by torchrpc allows bypass eventloop") + else: + self._event_loop.once(event, fn) + if self.router.is_active: + self.router.on(self._wrap_event_name(event), self._event_loop.emit) def off(self, event: str, fn: Optional[Callable] = None) -> None: """ diff --git a/ding/torch_utils/data_helper.py b/ding/torch_utils/data_helper.py index 8e6d026499..7b098698bd 100644 --- a/ding/torch_utils/data_helper.py +++ b/ding/torch_utils/data_helper.py @@ -62,6 +62,8 @@ def to_device(item: Any, device: str, ignore_keys: list = []) -> Any: return item elif isinstance(item, torch.distributions.Distribution): # for compatibility return item + elif isinstance(item, ttorch.Tensor): + return item.to(device) else: raise TypeError("not support item type: {}".format(type(item))) diff --git a/ding/utils/__init__.py b/ding/utils/__init__.py index 5e262b2c38..88f39e0d3d 100644 --- a/ding/utils/__init__.py +++ b/ding/utils/__init__.py @@ -11,7 +11,7 @@ from .k8s_helper import get_operator_server_kwargs, exist_operator_server, DEFAULT_K8S_COLLECTOR_PORT, \ DEFAULT_K8S_LEARNER_PORT, DEFAULT_K8S_AGGREGATOR_SLAVE_PORT, DEFAULT_K8S_COORDINATOR_PORT, pod_exec_command, \ K8sLauncher -from .lock_helper import LockContext, LockContextType, get_file_lock, get_rw_file_lock +from .lock_helper import LockContext, LockContextType, get_file_lock, get_rw_file_lock, synchronized from .log_helper import build_logger, pretty_print, LoggerFactory from .log_writer_helper import DistributedWriter from .orchestrator_launcher import OrchestratorLauncher @@ -36,3 +36,6 @@ else: from .pytorch_ddp_dist_helper import get_rank, get_world_size, dist_mode, dist_init, dist_finalize, \ allreduce, broadcast, DistContext, allreduce_async, synchronize + +from .comm_perf_helper import TENSOR_SIZE_LIST, DO_PERF, tensor_size_beauty_print, byte_beauty_print, \ + dtype_2_byte, time_perf_avg, time_perf_once, print_timer_result_csv diff --git a/ding/utils/comm_perf_helper.py b/ding/utils/comm_perf_helper.py new file mode 100644 index 0000000000..416b794c56 --- /dev/null +++ b/ding/utils/comm_perf_helper.py @@ -0,0 +1,145 @@ +import torch +import functools +import time +from concurrent import futures +from ditk import logging +from typing 
import List, Optional, Tuple, Dict, Any +from ding.utils import EasyTimer + +# Data size for some tests +UNIT_1_B = 1 +UNIT_1_KB = 1024 * UNIT_1_B +UNIT_1_MB = 1024 * UNIT_1_KB +UNIT_1_GB = 1024 * UNIT_1_MB +TENSOR_SIZE_LIST = [ + 8 * UNIT_1_B, 32 * UNIT_1_B, 64 * UNIT_1_B, UNIT_1_KB, 4 * UNIT_1_KB, 64 * UNIT_1_KB, 1 * UNIT_1_MB, 4 * UNIT_1_MB, + 64 * UNIT_1_MB, 512 * UNIT_1_MB, 1 * UNIT_1_GB, 2 * UNIT_1_GB, 4 * UNIT_1_GB +] + +# TODO: Add perf switch to avoid performance loss to critical paths during non-test time. +DO_PERF = False + +# Convert from torch.dtype to bytes +TYPE_MAP = {torch.float32: 4, torch.float64: 8, torch.int32: 4, torch.int64: 8, torch.uint8: 1} + +# A list of time units and names. +TIME_UNIT = [1, 1000, 1000] +TIME_NAME = ["s", "ms", "us"] + +# The global function timing result is stored in OUTPUT_DICT. +OUTPUT_DICT = dict() + + +def _store_timer_result(func_name: str, avg_tt: float): + if func_name not in OUTPUT_DICT.keys(): + OUTPUT_DICT[func_name] = str(round(avg_tt, 4)) + "," + else: + OUTPUT_DICT[func_name] = OUTPUT_DICT[func_name] + str(round(avg_tt, 4)) + "," + + +def print_timer_result_csv(): + """ + Overview: + Output the average execution time of all functions durning this + experiment in csv format. + """ + for key, value in OUTPUT_DICT.items(): + print("{},{}".format(key, value)) + + +def time_perf_once(unit: int, cuda: bool = False): + """ + Overview: + Decorator function to measure the time of a function execution. + Arguments: + - unit ([int]): 0 for s timer, 1 for ms timer, 2 for us timer. + - cuda (bool, optional): Whether CUDA operation occurred within the timing range. + """ + + def decorator(func): + + @functools.wraps(func) + def wrapper(*args, **kw): + timer = EasyTimer(cuda=cuda) + with timer: + func(*args, **kw) + tt = timer.value * TIME_UNIT[unit] + logging.info("func:\"{}\" use {:.4f} {},".format(func.__name__, tt, TIME_NAME[unit])) + + _store_timer_result(func.__name__, tt) + + return wrapper + + return decorator + + +def time_perf_avg(unit: int, count: int, skip_iter: int = 0, cuda: bool = False): + """ + Overview: + A decorator that averages the execution time of a function. + Arguments: + - unit (int): 0 for s timer, 1 for ms timer, 2 for us timer + - time_list (List): User-supplied list for staging execution times. + - count (int): Loop count. + - skip_iter (int, optional): Skip the first n iter times. + - cuda (bool, optional): Whether CUDA operation occurred within the timing range. + """ + time_list = [] + + if skip_iter >= count: + logging.error("skip_iter:[{}] must >= count:[{}]".format(skip_iter, count)) + return None + + def decorator(func): + + @functools.wraps(func) + def wrapper(idx, *args, **kw): + timer = EasyTimer(cuda=cuda) + with timer: + func(*args, **kw) + + if idx < skip_iter: + return + + time_list.append(timer.value * TIME_UNIT[unit]) + if idx == count - 1: + avg_tt = sum(time_list) / len(time_list) + logging.info( + "\"{}\": repeat[{}], avg_time[{:.4f}]{},".format( + func.__name__, len(time_list), avg_tt, TIME_NAME[unit] + ) + ) + + _store_timer_result(func.__name__, avg_tt) + time_list.clear() + + return wrapper + + return decorator + + +def dtype_2_byte(dtype: torch.dtype) -> int: + return TYPE_MAP[dtype] + + +def tensor_size_beauty_print(length: int, dtype: torch.dtype) -> tuple: + return byte_beauty_print(length * dtype_2_byte(dtype)) + + +def byte_beauty_print(nbytes: int) -> tuple: + """ + Overview: + Output the bytes in a human-readable format. + Arguments: + - nbytes (int): number of bytes. 
+ + Returns: + tuple: tuple of formatted bytes and units. + """ + unit_dict = [("GB", 1024 * 1024 * 1024), ("MB", 1024 * 1024), ("KB", 1024), ("B", 1)] + + for item in unit_dict: + if nbytes // item[1] > 0: + return nbytes / item[1], item[0] + + return nbytes, "B" diff --git a/ding/utils/lock_helper.py b/ding/utils/lock_helper.py index cb4a9c13b5..02c31c2191 100644 --- a/ding/utils/lock_helper.py +++ b/ding/utils/lock_helper.py @@ -2,6 +2,7 @@ import multiprocessing import threading import platform +import functools from enum import Enum, unique from readerwriterlock import rwlock @@ -12,6 +13,19 @@ fcntl = None +class DummyLock: + """ + DummyLock can be used in codes where locks are not required. + Reduce unnecessary code. + """ + + def acquire(self): + pass + + def release(self): + pass + + @unique class LockContextType(Enum): """ @@ -19,11 +33,15 @@ class LockContextType(Enum): """ THREAD_LOCK = 1 PROCESS_LOCK = 2 + DUMMY_LOCK = 3 + CONDITION_LOCK = 4 _LOCK_TYPE_MAPPING = { LockContextType.THREAD_LOCK: threading.Lock, LockContextType.PROCESS_LOCK: multiprocessing.Lock, + LockContextType.DUMMY_LOCK: DummyLock, + LockContextType.CONDITION_LOCK: threading.Condition } @@ -118,3 +136,24 @@ def get_file_lock(name: str, op: str) -> None: except Exception as e: pass return FcntlContext(lock_name) + + +def synchronized(func): + """ + Overview: + thread lock decorator. + Arguments: + - func ([type]): A function that needs to be protected by a lock. + """ + func.__lock__ = threading.Lock() + + def decorator(func): + + @functools.wraps(func) + def wrapper(*args, **kwargs): + with func.__lock__: + return func(*args, **kwargs) + + return wrapper + + return decorator diff --git a/dizoo/atari/example/atari_dqn_dist_ddp.py b/dizoo/atari/example/atari_dqn_dist_ddp.py index 6b615abb21..85b498d455 100644 --- a/dizoo/atari/example/atari_dqn_dist_ddp.py +++ b/dizoo/atari/example/atari_dqn_dist_ddp.py @@ -14,7 +14,6 @@ from dizoo.atari.envs.atari_env import AtariEnv from dizoo.atari.config.serial.pong.pong_dqn_config import main_config, create_config - logging.getLogger().setLevel(logging.INFO) main_config.exp_name = 'pong_dqn_seed0_ditask_dist_ddp' diff --git a/dizoo/atari/example/atari_dqn_dist_rdma.py b/dizoo/atari/example/atari_dqn_dist_rdma.py index 71fb1d64a1..e852108bda 100644 --- a/dizoo/atari/example/atari_dqn_dist_rdma.py +++ b/dizoo/atari/example/atari_dqn_dist_rdma.py @@ -8,15 +8,17 @@ from ding.framework import task, ding_init from ding.framework.context import OnlineRLContext from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \ - eps_greedy_handler, CkptSaver, context_exchanger, model_exchanger, termination_checker, nstep_reward_enhancer, \ - online_logger + eps_greedy_handler, CkptSaver, ContextExchanger, ModelExchanger, nstep_reward_enhancer, termination_checker from ding.utils import set_pkg_seed from dizoo.atari.envs.atari_env import AtariEnv from dizoo.atari.config.serial.pong.pong_dqn_config import main_config, create_config +from ding.utils import EasyTimer +import os +import time def main(): - logging.getLogger().setLevel(logging.INFO) + logger = logging.getLogger().setLevel(logging.DEBUG) main_config.exp_name = 'pong_dqn_seed0_dist_rdma' cfg = compile_config(main_config, create_cfg=create_config, auto=True) ding_init(cfg) @@ -26,46 +28,53 @@ def main(): set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) model = DQN(**cfg.policy.model) - policy = DQNPolicy(cfg.policy, model=model) + + # Consider the case with multiple processes + if 
task.router.is_active: + # You can use labels to distinguish between workers with different roles, + # here we use node_id to distinguish. + if task.router.node_id == 0: + task.add_role(task.role.LEARNER) + else: + task.add_role(task.role.COLLECTOR) + + logging.debug("label {}".format(task.router.labels)) + logging.debug("task role {}".format(task._roles)) if 'learner' in task.router.labels: + policy = DQNPolicy(cfg.policy, model=model) logging.info("Learner running on node {}".format(task.router.node_id)) buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size) - task.use( - context_exchanger( - send_keys=["train_iter"], - recv_keys=["trajectories", "episodes", "env_step", "env_episode"], - skip_n_iter=0 - ) - ) - task.use(model_exchanger(model, is_learner=True)) + task.use(ContextExchanger(skip_n_iter=0)) + task.use(ModelExchanger(model)) task.use(nstep_reward_enhancer(cfg)) task.use(data_pusher(cfg, buffer_)) task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) task.use(CkptSaver(cfg, policy, train_freq=1000)) elif 'collector' in task.router.labels: + policy = DQNPolicy(cfg.policy, model=model) logging.info("Collector running on node {}".format(task.router.node_id)) collector_cfg = deepcopy(cfg.env) collector_cfg.is_train = True + logging.info(cfg.env.manager) + logging.info(type(cfg.env.manager)) + # task.router.judge_use_cuda_shm(cfg) + logging.debug("cuda_shared_memory {}".format(cfg.env.manager.cuda_shared_memory)) collector_env = SubprocessEnvManagerV2( env_fn=[lambda: AtariEnv(collector_cfg) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager ) - task.use( - context_exchanger( - send_keys=["trajectories", "episodes", "env_step", "env_episode"], - recv_keys=["train_iter"], - skip_n_iter=1 - ) - ) - task.use(model_exchanger(model, is_learner=False)) + task.use(ContextExchanger(skip_n_iter=1)) + task.use(ModelExchanger(model)) task.use(eps_greedy_handler(cfg)) task.use(StepCollector(cfg, policy.collect_mode, collector_env)) task.use(termination_checker(max_env_step=int(1e7))) else: raise KeyError("invalid router labels: {}".format(task.router.labels)) - task.run() + start_time = task.run(max_step=100) + end_time = time.time() + logging.debug("atari iter 99 use {:.4f} s,".format(end_time - start_time)) if __name__ == "__main__": diff --git a/pytest.ini b/pytest.ini index efdeaba023..25c1e374a9 100644 --- a/pytest.ini +++ b/pytest.ini @@ -10,5 +10,7 @@ markers = envpooltest other tmp + multiprocesstest + mqbenchmark norecursedirs = ding/hpc_rl/tests
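For reference, the sketch below (not part of the diff) shows how the new torchrpc backend could be selected through Parallel.runner. Only the parameter names and the cuda_device_map string format come from the changes above; the worker body, the port, the two-GPU topology and the commented device map entry are illustrative assumptions.

from ding.framework.parallel import Parallel


def main():
    # Build the usual task pipeline here (see the atari_dqn_dist_rdma example above).
    ...


if __name__ == "__main__":
    Parallel.runner(
        n_parallel_workers=2,
        mq_type="torchrpc",                    # use the TORCHRPCMQ message queue
        node_ids=[0, 1],                       # globally unique rank per worker
        init_method="tcp://127.0.0.1:12399",   # rendezvous URL for init_rpc
        use_cuda=True,                         # allow CUDA tensors on the RPC channel
        # Optional explicit GPU mapping; each entry is "<peer_node_id>_<local_gpu>_<peer_gpu>",
        # e.g. "1_0_1" maps local GPU-0 to GPU-1 of Node_1.
        # cuda_device_map=["1_0_1"],
    )(main)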
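Similarly, a minimal sketch of the new timing helpers in ding/utils/comm_perf_helper.py; the timed function and the tensor size are placeholders.

import torch
from ding.utils import time_perf_avg, print_timer_result_csv

REPEAT = 10


# ms timer (unit=1), averaged over REPEAT calls, skipping the first 2 warm-up iterations.
@time_perf_avg(1, REPEAT, skip_iter=2)
def fake_send(t: torch.Tensor):
    _ = t * 2  # placeholder for the communication op being measured


if __name__ == "__main__":
    data = torch.randn(1024)
    # The wrapper consumes the current loop index as its first positional argument.
    for i in range(REPEAT):
        fake_send(i, data)
    # Print one "<func_name>,<avg_time>," row per timed function.
    print_timer_result_csv()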