Commit

polish

wangguoteng.p committed Feb 13, 2023
1 parent ca25b27 commit 44dcf13
Showing 15 changed files with 185 additions and 81 deletions.
3 changes: 0 additions & 3 deletions Makefile
@@ -17,14 +17,12 @@ WORKERS_COMMAND := $(if ${WORKERS},-n ${WORKERS} --dist=loadscope,)
DURATIONS ?= 10
DURATIONS_COMMAND := $(if ${DURATIONS},--durations=${DURATIONS},)

TIMEOUT_LIMIT ?= 300

docs:
$(MAKE) -C ${DING_DIR}/docs html

unittest:
pytest ${TEST_DIR} \
--timeout=${TIMEOUT_LIMIT} \
--cov-report=xml \
--cov-report term-missing \
--cov=${COV_DIR} \
@@ -39,7 +37,6 @@ algotest:

cudatest:
pytest ${TEST_DIR} \
--timeout=${TIMEOUT_LIMIT} \
-sv -m cudatest

envpooltest:
8 changes: 0 additions & 8 deletions codecov.yml
@@ -6,11 +6,3 @@ coverage:
target: auto
threshold: 0.5%
if_ci_failed: success #success, failure, error, ignore

# fix me
# The unittests of the torchrpc module are tested by different runners and cannot be included
# in the test_unittest's coverage report. To keep CI happy, we don't count torchrpc related coverage.
ignore:
- ./ding/framework/message_queue/torch_rpc.py
- ./ding/framework/message_queue/tests/test_torch_rpc.py
- ./ding/framework/message_queue/perfs/*
4 changes: 1 addition & 3 deletions ding/data/shm_buffer.py
@@ -158,9 +158,7 @@ def __init__(
self.copy_on_get = copy_on_get
self.shape = shape
self.device = device
# We don't want the buffer to be involved in the computational graph
with torch.no_grad():
self.buffer = torch.zeros(reduce(lambda x, y: x * y, shape), dtype=ttype, device=self.device)
self.buffer = torch.zeros(reduce(lambda x, y: x * y, shape), dtype=ttype, device=self.device)

def fill(self, src_arr: Union[np.ndarray, torch.Tensor]) -> None:
if self.ctype is np.ndarray:
114 changes: 91 additions & 23 deletions ding/data/tests/test_shm_buffer.py
@@ -1,11 +1,9 @@
from ding.data.shm_buffer import ShmBuffer, ShmBufferCuda
from ding.compatibility import torch_ge_1121

import pytest
import numpy as np
import timeit
import torch
import time
from ding.data.shm_buffer import ShmBuffer, ShmBufferCuda


def subprocess_np_shm(shm_buf):
@@ -14,8 +12,9 @@ def subprocess_np_shm(shm_buf):
print("Mean: {:.4f}s, STD: {:.4f}s, Mean each call: {:.4f}ms".format(np.mean(res), np.std(res), np.mean(res)))


def subprocess_cuda_shared_tensor(shm_buf_np, shm_buf_torch, event_run):
event_run.wait()
def subprocess_cuda_shared_tensor(shm_buf_np, shm_buf_torch, event_wait, event_fire, copy_on_get):
event_wait.wait()
event_wait.clear()
rtensor = shm_buf_torch.get()
assert isinstance(rtensor, torch.Tensor)
assert rtensor.device == torch.device('cuda:0')
@@ -26,12 +25,25 @@ def subprocess_cuda_shared_tensor(shm_buf_np, shm_buf_torch, event_run):
assert isinstance(rarray, np.ndarray)
assert rarray.dtype == np.dtype(np.float32)
assert rarray.dtype == np.dtype(np.float32)
assert rtensor.sum() == 1024 * 1024

shm_buf_torch.fill(torch.zeros((1024, 1024), dtype=torch.float32, device=torch.device('cuda:0')))
shm_buf_np.fill(np.zeros((1024, 1024), dtype=np.float32))

event_fire.set()

if copy_on_get:
event_wait.wait()
shm_buf_torch.buffer[0] = 9.0
shm_buf_np.buffer[0] = 9.0
event_fire.set()

del shm_buf_np
del shm_buf_torch

res = timeit.repeat(lambda shm_buf_torch=shm_buf_torch: shm_buf_torch.get(), repeat=5, number=1000)
print("CUDA-shared-tensor (torch) Get: mean: {:.4f}s, STD: {:.4f}s".format(np.mean(res), np.std(res)))
res = timeit.repeat(lambda shm_buf_np=shm_buf_np: shm_buf_np.get(), repeat=5, number=1000)
print("CUDA-shared-tensor (numpy) Get: mean: {:.4f}s, STD: {:.4f}s".format(np.mean(res), np.std(res)))

def subprocess_cuda_shared_tensor_case2(shm_buf_np, shm_buf_torch, event_wait):
event_wait.wait()
del shm_buf_np
del shm_buf_torch

@@ -49,42 +61,98 @@ def test_shm_buffer():
@pytest.mark.benchmark
@pytest.mark.cudatest
# @pytest.mark.multiprocesstest
def test_cuda_shm():
if torch.cuda.is_available() and torch.cuda.device_count() >= 2:
@pytest.mark.parametrize("copy_on_get", [True, False])
def test_cuda_shm(copy_on_get):
if torch.cuda.is_available():
import torch.multiprocessing as mp
ctx = mp.get_context('spawn')

event_run = ctx.Event()
shm_buf_np = ShmBufferCuda(np.dtype(np.float32), shape=(1024, 1024), copy_on_get=True)
shm_buf_torch = ShmBufferCuda(torch.float32, shape=(1024, 1024), copy_on_get=True)
proc = ctx.Process(target=subprocess_cuda_shared_tensor, args=[shm_buf_np, shm_buf_torch, event_run])
event_fire, event_wait = ctx.Event(), ctx.Event()
shm_buf_np = ShmBufferCuda(np.dtype(np.float32), shape=(1024, 1024), copy_on_get=copy_on_get)
shm_buf_torch = ShmBufferCuda(torch.float32, shape=(1024, 1024), copy_on_get=copy_on_get)
proc = ctx.Process(
target=subprocess_cuda_shared_tensor, args=[shm_buf_np, shm_buf_torch, event_fire, event_wait, copy_on_get]
)
proc.start()

ltensor = torch.ones((1024, 1024), dtype=torch.float32).cuda(0 if torch.cuda.device_count() == 1 else 1)
larray = np.random.rand(1024, 1024).astype(np.float32)
ltensor = torch.ones((1024, 1024), dtype=torch.float32, device=torch.device('cuda:0'))
larray = np.ones((1024, 1024), dtype=np.float32)
shm_buf_torch.fill(ltensor)
shm_buf_np.fill(larray)

res = timeit.repeat(lambda shm_buf_torch=shm_buf_torch: shm_buf_torch.fill(ltensor), repeat=5, number=1000)
print("CUDA-shared-tensor (torch) Fill: mean: {:.4f}s, STD: {:.4f}s".format(np.mean(res), np.std(res)))
res = timeit.repeat(lambda shm_buf_np=shm_buf_np: shm_buf_np.fill(larray), repeat=5, number=1000)
print("CUDA-shared-tensor (numpy) Fill: mean: {:.4f}s, STD: {:.4f}s".format(np.mean(res), np.std(res)))

rtensor = shm_buf_torch.get()
assert isinstance(rtensor, torch.Tensor)
assert rtensor.device == torch.device('cuda:0')
assert rtensor.shape == ltensor.shape
assert rtensor.dtype == ltensor.dtype
assert rtensor.sum().item() == 1024 * 1024

rarray = shm_buf_np.get()
assert isinstance(rarray, np.ndarray)
assert larray.shape == rarray.shape
assert larray.dtype == rarray.dtype
assert larray.sum() == 1024 * 1024

event_fire.set()
event_wait.wait()
event_wait.clear()
rtensor = shm_buf_torch.get()
assert isinstance(rtensor, torch.Tensor)
assert rtensor.device == torch.device('cuda:0')
assert rtensor.shape == ltensor.shape
assert rtensor.dtype == ltensor.dtype
assert rtensor.sum().item() == 0

rarray = shm_buf_np.get()
assert isinstance(rarray, np.ndarray)
assert rarray.shape == larray.shape
assert rarray.dtype == larray.dtype
assert rarray.sum() == 0

event_run.set()
if copy_on_get:
event_fire.set()
event_wait.wait()
assert shm_buf_torch.buffer[0].item() == 9.0
assert shm_buf_np.buffer[0] == 9.0

# Keep the producer process running until all consumers exit.
proc.join()

del shm_buf_np
del shm_buf_torch


@pytest.mark.benchmark
@pytest.mark.cudatest
# @pytest.mark.multiprocesstest
@pytest.mark.parametrize("copy_on_get", [True, False])
def test_cudabuff_perf(copy_on_get):
if torch.cuda.is_available():
import torch.multiprocessing as mp
ctx = mp.get_context('spawn')

event_fire, event_wait = ctx.Event(), ctx.Event()
shm_buf_np = ShmBufferCuda(np.dtype(np.float32), shape=(1024, 1024), copy_on_get=copy_on_get)
shm_buf_torch = ShmBufferCuda(torch.float32, shape=(1024, 1024), copy_on_get=copy_on_get)
proc = ctx.Process(target=subprocess_cuda_shared_tensor_case2, args=[shm_buf_np, shm_buf_torch, event_fire])
proc.start()

ltensor = torch.ones((1024, 1024), dtype=torch.float32, device=torch.device('cuda:0'))
larray = np.ones((1024, 1024), dtype=np.float32)
shm_buf_torch.fill(ltensor)
shm_buf_np.fill(larray)

res = timeit.repeat(lambda shm_buf_torch=shm_buf_torch: shm_buf_torch.fill(ltensor), repeat=5, number=1000)
print("CUDA-shared-tensor (torch) Fill: mean: {:.4f}s, STD: {:.4f}s".format(np.mean(res), np.std(res)))
res = timeit.repeat(lambda shm_buf_np=shm_buf_np: shm_buf_np.fill(larray), repeat=5, number=1000)
print("CUDA-shared-tensor (numpy) Fill: mean: {:.4f}s, STD: {:.4f}s".format(np.mean(res), np.std(res)))

res = timeit.repeat(lambda shm_buf_torch=shm_buf_torch: shm_buf_torch.get(), repeat=5, number=1000)
print("CUDA-shared-tensor (torch) Get: mean: {:.4f}s, STD: {:.4f}s".format(np.mean(res), np.std(res)))
res = timeit.repeat(lambda shm_buf_np=shm_buf_np: shm_buf_np.get(), repeat=5, number=1000)
print("CUDA-shared-tensor (numpy) Get: mean: {:.4f}s, STD: {:.4f}s".format(np.mean(res), np.std(res)))
event_fire.set()
proc.join()

del shm_buf_np
del shm_buf_torch
13 changes: 5 additions & 8 deletions ding/entry/cli_ditask.py
@@ -58,10 +58,7 @@ def print_version(ctx: Context, param: Option, value: bool) -> None:
@click.option("--platform-spec", type=str, help="Platform specific configure.")
@click.option("--platform", type=str, help="Platform type: slurm, k8s.")
@click.option(
"--mq-type",
type=str,
default="nng",
help="Class type of message queue, i.e. nng, redis, torchrpc:cuda, torchrpc:cpu."
"--mq-type", type=str, default="nng", help="Class type of message queue, i.e. nng, redis, cuda, torchrpc:cpu."
)
@click.option("--redis-host", type=str, help="Redis host.")
@click.option("--redis-port", type=int, help="Redis port.")
@@ -173,10 +170,10 @@ def _cli_ditask(
node_ids = node_ids.split(",")
node_ids = list(map(lambda i: int(i), node_ids))
use_cuda = False
if mq_type == "torchrpc:cuda" or mq_type == "torchrpc:cpu":
mq_type, use_cuda = mq_type.split(":")
if use_cuda == "cuda":
use_cuda = True
if mq_type == "cuda":
mq_type, use_cuda = "torchrpc", True
if mq_type == "torchrpc:cpu":
mq_type, use_cuda = "torchrpc", False
if local_cuda_devices:
local_cuda_devices = local_cuda_devices.split(",")
local_cuda_devices = list(map(lambda s: s.strip(), local_cuda_devices))
6 changes: 6 additions & 0 deletions ding/envs/env_manager/subprocess_env_manager.py
@@ -118,6 +118,12 @@ def __init__(
if not self._auto_reset:
assert not self._reset_inplace, "reset_inplace is unavailable when auto_reset=False."

if self._cfg.cuda_shared_memory and not self._cuda_shared_memory:
logging.warning(
"Option 'cuda_shared_memory' is true but 'shared_memory' is False, 'cuda_shared_memory'"
" will not be used."
)

def _create_state(self) -> None:
r"""
Overview:
43 changes: 42 additions & 1 deletion ding/framework/message_queue/README.md
@@ -1,5 +1,42 @@
# Notes on using torchrpc

## Performance
We conducted performance tests in a k8s environment equipped with A100-80GB and 200G HCA.

### Intra-node GPU-P2P performance

| test case (unit: ms) | 1.25 KB | 20.00 KB | 1.25 MB | 10.00 MB | 40.00 MB | 640.00 MB | 1.25 GB |
| ------------------ | ------- | -------- | ------- | -------- | ------- | -------- | -------- |
| shm | 0.3605 | 0.352 | 0.9924 | 7.1229 | 47.9575 | 798.8635 | 1548.782 |
| nccl-nvlink | 0.1969 | 0.1104 | 0.2162 | 0.3285 | 0.4532 | 3.3166 | 5.3828 |
| cuda-shared-tensor | 0.5307 | 0.578 | 0.9643 | 0.5908 | 1.2449 | 5.3707 | 9.686 |
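
For reference, the `cuda-shared-tensor` row corresponds to the `ShmBufferCuda` buffer touched by this commit. A minimal sketch of the fill/get path being timed is shown below; it relies only on the constructor and methods that appear in `ding/data/tests/test_shm_buffer.py`, and the sizes and repeat counts are illustrative.

```python
# Sketch of the cuda-shared-tensor measurement, assuming the ShmBufferCuda API
# used in ding/data/tests/test_shm_buffer.py; sizes and repeat counts are
# illustrative only.
import timeit

import numpy as np
import torch

from ding.data.shm_buffer import ShmBufferCuda

if torch.cuda.is_available():
    buf = ShmBufferCuda(torch.float32, shape=(1024, 1024), copy_on_get=True)
    tensor = torch.ones((1024, 1024), dtype=torch.float32, device="cuda:0")

    buf.fill(tensor)   # producer side: copy the tensor into the shared CUDA buffer
    out = buf.get()    # consumer side: read it back as a cuda:0 tensor

    res = timeit.repeat(lambda: buf.fill(tensor), repeat=5, number=1000)
    print("fill mean: {:.4f}s, std: {:.4f}s".format(np.mean(res), np.std(res)))
```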

### Inter-node GPU-P2P performance

| test case (unit: ms) | 20.00 KB | 1.25 MB | 10.00 MB | 40.00 MB | 640.00 MB | 1.25 GB | 2.50 GB |
| ------------------------ | -------- | ------- | -------- | -------- | --------- | --------- | ---------- |
| nng-TCP | 5.7353 | 9.6782 | 30.5187 | 172.9719 | 3450.7418 | 7083.6372 | 14072.1213 |
| nccl-TCP | 0.0826 | 1.321 | 31.7813 | 128.0672 | 1259.72 | 2477.2957 | 5157.7578 |
| nccl-IB | 0.0928 | 0.5618 | 2.1134 | 7.1768 | 120.131 | 260.2628 | 518.8091 |
| nccl-GDR (PXN<->PXN) | 0.5541 | 45.601 | 9.3636 | 19.3071 | 108.11 | 280.0556 | 527.9732 |
| torchrpc-TCP | 5.6691 | 5.4707 | 14.0155 | 39.4443 | 580.333 | 1154.0793 | 2297.3776 |
| torchrpc-IB | 21.3884 | 4.4093 | 5.9105 | 22.3012 | 130.249 | 236.8084 | 477.2389 |
| torchrpc-GDR (PXN<->PXN) | 20.5018 | 23.2081 | 15.6427 | 7.5357* | 48.7812 | 77.2657 | 143.4112 |

### Atari performance
Performance of dizoo/atari/example/atari_dqn_dist_rdma.py
- memory: "32Gi"
- cpu: 16
- gpu: A100


| test case (unit: s) | avg |
| ----------------- | ------- |
| TCP-nng | 127.64 |
| torchrpc-TCP | 29.3906 |
| torchrpc-IB | 28.7763 |


## Problems you may encounter

The torchrpc message queue uses [tensorpipe](https://github.com/pytorch/tensorpipe), a high-performance modular tensor-p2p communication library, as its communication backend. However, we found several tensorpipe defects during testing that may make it difficult to use.
@@ -10,4 +47,8 @@ Tensorpipe is not container aware. Processes can find themselves on the same phy

### 2. RDMA and fork subprocess

Tensorpipe does not handle the case of calling [fork(2)](https://man7.org/linux/man-pages/man2/fork.2.html) while RDMA is in use. If the corresponding initialization is not performed, forking will cause serious problems; see [here](https://www.rdmamojo.com/2012/05/24/ibv_fork_init/). Therefore, if you start ditask in an IB/RoCE network environment, set the environment variables `IBV_FORK_SAFE=1` and `RDMAV_FORK_SAFE=1` so that ibverbs automatically initializes fork support.
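
A minimal sketch of one way to set these variables is shown below. This is only an illustration: exporting `IBV_FORK_SAFE=1 RDMAV_FORK_SAFE=1` in the shell that launches ditask achieves the same thing.

```python
# Illustrative sketch: make sure the fork-safety variables are set before any
# ibverbs/torchrpc initialization happens, e.g. at the very top of the entry
# script that launches ditask.
import os

os.environ.setdefault("IBV_FORK_SAFE", "1")
os.environ.setdefault("RDMAV_FORK_SAFE", "1")

# ... import and start RDMA-related components only after this point ...
```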

### 3. GPU direct RDMA

If you use torchrpc in an environment that supports GPU direct RDMA and the tensor transmitted in an rpc is very small (less than 32 B), a segmentation fault may occur. See [this issue](https://github.com/pytorch/pytorch/issues/57136). We are tracking the bug and hope it can eventually be resolved.
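
Until the upstream fix lands, a possible mitigation (our own assumption, mirroring the "don't make tensor size too small" strategy used in the collector middleware, not an official API) is to pad tiny tensors before handing them to rpc:

```python
# Hedged sketch: pad very small CUDA tensors so an rpc payload never falls
# below the ~32-byte threshold mentioned above. The helper name and padding
# strategy are illustrative; the receiver must know the original shape in
# order to undo the padding.
import torch

_MIN_BYTES = 32


def pad_for_gdr(t: torch.Tensor) -> torch.Tensor:
    nbytes = t.numel() * t.element_size()
    if nbytes >= _MIN_BYTES:
        return t
    pad_elems = (_MIN_BYTES - nbytes + t.element_size() - 1) // t.element_size()
    return torch.cat([t.flatten(), t.new_zeros(pad_elems)])
```
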
7 changes: 4 additions & 3 deletions ding/framework/message_queue/perfs/perf_shm.py
@@ -3,8 +3,9 @@
from ditk import logging
from ding.framework.supervisor import RecvPayload, SendPayload, Supervisor, ChildType
from ding.envs.env_manager.subprocess_env_manager import ShmBufferContainer, ShmBuffer
from ding.utils.comm_perf_helper import tensor_size_beauty_print, byte_beauty_print, \
from ding.utils.comm_perf_helper import tensor_size_beauty_print, \
dtype_2_byte, TENSOR_SIZE_LIST, print_timer_result_csv
from ding.utils import byte_beauty_print

import torch
import numpy as np
@@ -37,7 +38,7 @@ def cuda_shm_callback(payload: RecvPayload, buffers: Any):
assert tensor.device == torch.device('cuda:1')


class Recvier:
class Receiver:

def step(self, idx: int, __start_time):
return {"idx": idx, "start_time": __start_time}
@@ -56,7 +57,7 @@ def __init__(self, gpu_tensors, buffers, ctx, is_cuda_buffer):
_shm_callback = shm_callback
else:
_shm_callback = cuda_shm_callback
self.register(Recvier, shm_buffer=self.buffers, shm_callback=_shm_callback)
self.register(Receiver, shm_buffer=self.buffers, shm_callback=_shm_callback)
super().start_link()

def _send_recv_callback(self, payload: RecvPayload, remain_payloads: Optional[Dict[str, SendPayload]] = None):
Expand Down
3 changes: 2 additions & 1 deletion ding/framework/message_queue/perfs/perf_torchrpc_nccl.py
@@ -12,8 +12,9 @@

from ding.utils.data.structure.lifo_deque import LifoDeque
from ding.framework.message_queue.torch_rpc import DeviceMap, TORCHRPCMQ, RPCEvent
from ding.utils.comm_perf_helper import tensor_size_beauty_print, byte_beauty_print, \
from ding.utils.comm_perf_helper import tensor_size_beauty_print, \
dtype_2_byte, DO_PERF, time_perf_avg, time_perf_once, print_timer_result_csv
from ding.utils import byte_beauty_print

LENGTH = 5
REPEAT = 2
Expand Down
15 changes: 9 additions & 6 deletions ding/framework/middleware/functional/collector.py
@@ -136,11 +136,15 @@ def _rollout(ctx: "OnlineRLContext"):
# torchrpc currently uses "cuda:0" as the transmission device by default,
# so all data on the cpu side is copied to "cuda:0" here. In fact this
# copy is unnecessary, because torchrpc can support both cpu side and gpu
# side data to communicate using RDMA, but mixing the two transfer types
# will cause a bug, see issue:
# Because we have copied the large payload "obs" and "next_obs" from the
# collector's subprocess to "cuda:0" in advance, the copy operation here
# will not have too much overhead.
# side data to communicate using RDMA.
# However, we hit a bug in the unit tests, see: https://github.com/pytorch/pytorch/issues/57136
# We adopted some strategies to avoid this bug:
# 1. Try not to mix CPU and GPU args in one rpc.
# Because we have copied the large payload "obs" and "next_obs" from the
# collector's subprocess to "cuda:0" in advance, the copy operation here
# will not have too much overhead.
# 2. Don't make tensor size too small when using gpu direct RDMA.

if use_cuda_shared_memory:
transition = to_device(transition, "cuda:0")
transitions.append(timestep.env_id, transition)
@@ -149,6 +153,5 @@ env_episode_id[timestep.env_id] = current_id
env_episode_id[timestep.env_id] = current_id
current_id += 1
ctx.env_episode += 1
# TODO log

return _rollout
2 changes: 1 addition & 1 deletion ding/framework/middleware/functional/trainer.py
@@ -71,7 +71,7 @@ def _train(ctx: Union["OnlineRLContext", "OfflineRLContext"]):

if ctx.train_data is None: # no enough data from data fetcher
return
# data = ctx.train_data.to(policy._device)
data = ctx.train_data.to(policy._device)
train_output = policy.forward(ctx.train_data)
nonlocal last_log_iter
if ctx.train_iter - last_log_iter >= log_freq:
