From 44dcf13f7ae74eb63041c25cf8bb376bfd6d014c Mon Sep 17 00:00:00 2001 From: "wangguoteng.p" Date: Mon, 13 Feb 2023 12:26:16 +0800 Subject: [PATCH] polish --- Makefile | 3 - codecov.yml | 8 -- ding/data/shm_buffer.py | 4 +- ding/data/tests/test_shm_buffer.py | 114 ++++++++++++++---- ding/entry/cli_ditask.py | 13 +- .../env_manager/subprocess_env_manager.py | 6 + ding/framework/message_queue/README.md | 43 ++++++- .../framework/message_queue/perfs/perf_shm.py | 7 +- .../message_queue/perfs/perf_torchrpc_nccl.py | 3 +- .../middleware/functional/collector.py | 15 ++- .../middleware/functional/trainer.py | 2 +- ding/utils/__init__.py | 6 +- ding/utils/comm_perf_helper.py | 21 +--- ding/utils/lock_helper.py | 2 +- ding/utils/log_helper.py | 19 +++ 15 files changed, 185 insertions(+), 81 deletions(-) diff --git a/Makefile b/Makefile index cc53a7ca57..b892e63b4b 100644 --- a/Makefile +++ b/Makefile @@ -17,14 +17,12 @@ WORKERS_COMMAND := $(if ${WORKERS},-n ${WORKERS} --dist=loadscope,) DURATIONS ?= 10 DURATIONS_COMMAND := $(if ${DURATIONS},--durations=${DURATIONS},) -TIMEOUT_LIMIT ?= 300 docs: $(MAKE) -C ${DING_DIR}/docs html unittest: pytest ${TEST_DIR} \ - --timeout=${TIMEOUT_LIMIT} \ --cov-report=xml \ --cov-report term-missing \ --cov=${COV_DIR} \ @@ -39,7 +37,6 @@ algotest: cudatest: pytest ${TEST_DIR} \ - --timeout=${TIMEOUT_LIMIT} \ -sv -m cudatest envpooltest: diff --git a/codecov.yml b/codecov.yml index c2a03fbfbe..0779ada773 100644 --- a/codecov.yml +++ b/codecov.yml @@ -6,11 +6,3 @@ coverage: target: auto threshold: 0.5% if_ci_failed: success #success, failure, error, ignore - -# fix me -# The unittests of the torchrpc module are tested by different runners and cannot be included -# in the test_unittest's coverage report. To keep CI happy, we don't count torchrpc related coverage. 
-ignore: - - ./ding/framework/message_queue/torch_rpc.py - - ./ding/framework/message_queue/tests/test_torch_rpc.py - - ./ding/framework/message_queue/perfs/* diff --git a/ding/data/shm_buffer.py b/ding/data/shm_buffer.py index 875a7210c7..478ca105bd 100644 --- a/ding/data/shm_buffer.py +++ b/ding/data/shm_buffer.py @@ -158,9 +158,7 @@ def __init__( self.copy_on_get = copy_on_get self.shape = shape self.device = device - # We don't want the buffer to be involved in the computational graph - with torch.no_grad(): - self.buffer = torch.zeros(reduce(lambda x, y: x * y, shape), dtype=ttype, device=self.device) + self.buffer = torch.zeros(reduce(lambda x, y: x * y, shape), dtype=ttype, device=self.device) def fill(self, src_arr: Union[np.ndarray, torch.Tensor]) -> None: if self.ctype is np.ndarray: diff --git a/ding/data/tests/test_shm_buffer.py b/ding/data/tests/test_shm_buffer.py index 2125735925..097eb28201 100644 --- a/ding/data/tests/test_shm_buffer.py +++ b/ding/data/tests/test_shm_buffer.py @@ -1,11 +1,9 @@ -from ding.data.shm_buffer import ShmBuffer, ShmBufferCuda -from ding.compatibility import torch_ge_1121 - import pytest import numpy as np import timeit import torch import time +from ding.data.shm_buffer import ShmBuffer, ShmBufferCuda def subprocess_np_shm(shm_buf): @@ -14,8 +12,9 @@ def subprocess_np_shm(shm_buf): print("Mean: {:.4f}s, STD: {:.4f}s, Mean each call: {:.4f}ms".format(np.mean(res), np.std(res), np.mean(res))) -def subprocess_cuda_shared_tensor(shm_buf_np, shm_buf_torch, event_run): - event_run.wait() +def subprocess_cuda_shared_tensor(shm_buf_np, shm_buf_torch, event_wait, event_fire, copy_on_get): + event_wait.wait() + event_wait.clear() rtensor = shm_buf_torch.get() assert isinstance(rtensor, torch.Tensor) assert rtensor.device == torch.device('cuda:0') @@ -26,12 +25,25 @@ def subprocess_cuda_shared_tensor(shm_buf_np, shm_buf_torch, event_run): assert isinstance(rarray, np.ndarray) assert rarray.dtype == np.dtype(np.float32) assert rarray.dtype == np.dtype(np.float32) + assert rtensor.sum() == 1024 * 1024 + + shm_buf_torch.fill(torch.zeros((1024, 1024), dtype=torch.float32, device=torch.device('cuda:0'))) + shm_buf_np.fill(np.zeros((1024, 1024), dtype=np.float32)) + + event_fire.set() + + if copy_on_get: + event_wait.wait() + shm_buf_torch.buffer[0] = 9.0 + shm_buf_np.buffer[0] = 9.0 + event_fire.set() + + del shm_buf_np + del shm_buf_torch - res = timeit.repeat(lambda shm_buf_torch=shm_buf_torch: shm_buf_torch.get(), repeat=5, number=1000) - print("CUDA-shared-tensor (torch) Get: mean: {:.4f}s, STD: {:.4f}s".format(np.mean(res), np.std(res))) - res = timeit.repeat(lambda shm_buf_np=shm_buf_np: shm_buf_np.get(), repeat=5, number=1000) - print("CUDA-shared-tensor (numpy) Get: mean: {:.4f}s, STD: {:.4f}s".format(np.mean(res), np.std(res))) +def subprocess_cuda_shared_tensor_case2(shm_buf_np, shm_buf_torch, event_wait): + event_wait.wait() del shm_buf_np del shm_buf_torch @@ -49,42 +61,98 @@ def test_shm_buffer(): @pytest.mark.benchmark @pytest.mark.cudatest # @pytest.mark.multiprocesstest -def test_cuda_shm(): - if torch.cuda.is_available() and torch.cuda.device_count() >= 2: +@pytest.mark.parametrize("copy_on_get", [True, False]) +def test_cuda_shm(copy_on_get): + if torch.cuda.is_available(): import torch.multiprocessing as mp ctx = mp.get_context('spawn') - event_run = ctx.Event() - shm_buf_np = ShmBufferCuda(np.dtype(np.float32), shape=(1024, 1024), copy_on_get=True) - shm_buf_torch = ShmBufferCuda(torch.float32, shape=(1024, 1024), copy_on_get=True) - proc = 
ctx.Process(target=subprocess_cuda_shared_tensor, args=[shm_buf_np, shm_buf_torch, event_run]) + event_fire, event_wait = ctx.Event(), ctx.Event() + shm_buf_np = ShmBufferCuda(np.dtype(np.float32), shape=(1024, 1024), copy_on_get=copy_on_get) + shm_buf_torch = ShmBufferCuda(torch.float32, shape=(1024, 1024), copy_on_get=copy_on_get) + proc = ctx.Process( + target=subprocess_cuda_shared_tensor, args=[shm_buf_np, shm_buf_torch, event_fire, event_wait, copy_on_get] + ) proc.start() - ltensor = torch.ones((1024, 1024), dtype=torch.float32).cuda(0 if torch.cuda.device_count() == 1 else 1) - larray = np.random.rand(1024, 1024).astype(np.float32) + ltensor = torch.ones((1024, 1024), dtype=torch.float32, device=torch.device('cuda:0')) + larray = np.ones((1024, 1024), dtype=np.float32) shm_buf_torch.fill(ltensor) shm_buf_np.fill(larray) - res = timeit.repeat(lambda shm_buf_torch=shm_buf_torch: shm_buf_torch.fill(ltensor), repeat=5, number=1000) - print("CUDA-shared-tensor (torch) Fill: mean: {:.4f}s, STD: {:.4f}s".format(np.mean(res), np.std(res))) - res = timeit.repeat(lambda shm_buf_np=shm_buf_np: shm_buf_np.fill(larray), repeat=5, number=1000) - print("CUDA-shared-tensor (numpy) Fill: mean: {:.4f}s, STD: {:.4f}s".format(np.mean(res), np.std(res))) - rtensor = shm_buf_torch.get() assert isinstance(rtensor, torch.Tensor) assert rtensor.device == torch.device('cuda:0') assert rtensor.shape == ltensor.shape assert rtensor.dtype == ltensor.dtype + assert rtensor.sum().item() == 1024 * 1024 rarray = shm_buf_np.get() assert isinstance(rarray, np.ndarray) assert larray.shape == rarray.shape assert larray.dtype == rarray.dtype + assert larray.sum() == 1024 * 1024 + + event_fire.set() + event_wait.wait() + event_wait.clear() + rtensor = shm_buf_torch.get() + assert isinstance(rtensor, torch.Tensor) + assert rtensor.device == torch.device('cuda:0') + assert rtensor.shape == ltensor.shape + assert rtensor.dtype == ltensor.dtype + assert rtensor.sum().item() == 0 + + rarray = shm_buf_np.get() + assert isinstance(rarray, np.ndarray) + assert rarray.shape == larray.shape + assert rarray.dtype == larray.dtype + assert rarray.sum() == 0 - event_run.set() + if copy_on_get: + event_fire.set() + event_wait.wait() + assert shm_buf_torch.buffer[0].item() == 9.0 + assert shm_buf_np.buffer[0] == 9.0 # Keep producer process running until all consumers exits. 
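        # (Presumably because CUDA shared tensors are exported to the consumer via IPC
        # handles, which stay valid only while the producer process is alive.)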
proc.join() del shm_buf_np del shm_buf_torch + + +@pytest.mark.benchmark +@pytest.mark.cudatest +# @pytest.mark.multiprocesstest +@pytest.mark.parametrize("copy_on_get", [True, False]) +def test_cudabuff_perf(copy_on_get): + if torch.cuda.is_available(): + import torch.multiprocessing as mp + ctx = mp.get_context('spawn') + + event_fire, event_wait = ctx.Event(), ctx.Event() + shm_buf_np = ShmBufferCuda(np.dtype(np.float32), shape=(1024, 1024), copy_on_get=copy_on_get) + shm_buf_torch = ShmBufferCuda(torch.float32, shape=(1024, 1024), copy_on_get=copy_on_get) + proc = ctx.Process(target=subprocess_cuda_shared_tensor_case2, args=[shm_buf_np, shm_buf_torch, event_fire]) + proc.start() + + ltensor = torch.ones((1024, 1024), dtype=torch.float32, device=torch.device('cuda:0')) + larray = np.ones((1024, 1024), dtype=np.float32) + shm_buf_torch.fill(ltensor) + shm_buf_np.fill(larray) + + res = timeit.repeat(lambda shm_buf_torch=shm_buf_torch: shm_buf_torch.fill(ltensor), repeat=5, number=1000) + print("CUDA-shared-tensor (torch) Fill: mean: {:.4f}s, STD: {:.4f}s".format(np.mean(res), np.std(res))) + res = timeit.repeat(lambda shm_buf_np=shm_buf_np: shm_buf_np.fill(larray), repeat=5, number=1000) + print("CUDA-shared-tensor (numpy) Fill: mean: {:.4f}s, STD: {:.4f}s".format(np.mean(res), np.std(res))) + + res = timeit.repeat(lambda shm_buf_torch=shm_buf_torch: shm_buf_torch.get(), repeat=5, number=1000) + print("CUDA-shared-tensor (torch) Get: mean: {:.4f}s, STD: {:.4f}s".format(np.mean(res), np.std(res))) + res = timeit.repeat(lambda shm_buf_np=shm_buf_np: shm_buf_np.get(), repeat=5, number=1000) + print("CUDA-shared-tensor (numpy) Get: mean: {:.4f}s, STD: {:.4f}s".format(np.mean(res), np.std(res))) + event_fire.set() + proc.join() + + del shm_buf_np + del shm_buf_torch diff --git a/ding/entry/cli_ditask.py b/ding/entry/cli_ditask.py index 29af0af2ad..f6b0e7922f 100644 --- a/ding/entry/cli_ditask.py +++ b/ding/entry/cli_ditask.py @@ -58,10 +58,7 @@ def print_version(ctx: Context, param: Option, value: bool) -> None: @click.option("--platform-spec", type=str, help="Platform specific configure.") @click.option("--platform", type=str, help="Platform type: slurm, k8s.") @click.option( - "--mq-type", - type=str, - default="nng", - help="Class type of message queue, i.e. nng, redis, torchrpc:cuda, torchrpc:cpu." + "--mq-type", type=str, default="nng", help="Class type of message queue, i.e. nng, redis, cuda, torchrpc:cpu." ) @click.option("--redis-host", type=str, help="Redis host.") @click.option("--redis-port", type=int, help="Redis port.") @@ -173,10 +170,10 @@ def _cli_ditask( node_ids = node_ids.split(",") node_ids = list(map(lambda i: int(i), node_ids)) use_cuda = False - if mq_type == "torchrpc:cuda" or mq_type == "torchrpc:cpu": - mq_type, use_cuda = mq_type.split(":") - if use_cuda == "cuda": - use_cuda = True + if mq_type == "cuda": + mq_type, use_cuda = "torchrpc", True + if mq_type == "torchrpc:cpu": + mq_type, use_cuda = "torchrpc", False if local_cuda_devices: local_cuda_devices = local_cuda_devices.split(",") local_cuda_devices = list(map(lambda s: s.strip(), local_cuda_devices)) diff --git a/ding/envs/env_manager/subprocess_env_manager.py b/ding/envs/env_manager/subprocess_env_manager.py index fdcc61de17..94e4e46b0e 100644 --- a/ding/envs/env_manager/subprocess_env_manager.py +++ b/ding/envs/env_manager/subprocess_env_manager.py @@ -118,6 +118,12 @@ def __init__( if not self._auto_reset: assert not self._reset_inplace, "reset_inplace is unavailable when auto_reset=False." 
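+        # 'cuda_shared_memory' only takes effect together with 'shared_memory';
+        # warn (rather than fail) if it is requested but cannot be used.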
+        if self._cfg.cuda_shared_memory and not self._cuda_shared_memory:
+            logging.warning(
+                "Option 'cuda_shared_memory' is True but 'shared_memory' is False, so"
+                " 'cuda_shared_memory' will not be used."
+            )
+
     def _create_state(self) -> None:
         r"""
         Overview:
diff --git a/ding/framework/message_queue/README.md b/ding/framework/message_queue/README.md
index 3267dbecfd..4610534575 100644
--- a/ding/framework/message_queue/README.md
+++ b/ding/framework/message_queue/README.md
@@ -1,5 +1,42 @@
 # Notes on using torchrpc
 
+## Performance
+We conducted performance tests in a k8s environment equipped with A100-80GB GPUs and a 200G HCA.
+
+### Intra-node GPU-P2P performance
+
+| test case (unit: ms) | 1.25 KB | 20.00 KB | 1.25 MB | 10.00 MB | 40.00 MB | 640.00 MB | 1.25 GB  |
+| -------------------- | ------- | -------- | ------- | -------- | -------- | --------- | -------- |
+| shm                  | 0.3605  | 0.352    | 0.9924  | 7.1229   | 47.9575  | 798.8635  | 1548.782 |
+| nccl-nvlink          | 0.1969  | 0.1104   | 0.2162  | 0.3285   | 0.4532   | 3.3166    | 5.3828   |
+| cuda-shared-tensor   | 0.5307  | 0.578    | 0.9643  | 0.5908   | 1.2449   | 5.3707    | 9.686    |
+
+### Inter-node GPU-P2P performance
+
+| test case (unit: ms)     | 20.00 KB | 1.25 MB | 10.00 MB | 40.00 MB | 640.00 MB | 1.25 GB   | 2.50 GB    |
+| ------------------------ | -------- | ------- | -------- | -------- | --------- | --------- | ---------- |
+| nng-TCP                  | 5.7353   | 9.6782  | 30.5187  | 172.9719 | 3450.7418 | 7083.6372 | 14072.1213 |
+| nccl-TCP                 | 0.0826   | 1.321   | 31.7813  | 128.0672 | 1259.72   | 2477.2957 | 5157.7578  |
+| nccl-IB                  | 0.0928   | 0.5618  | 2.1134   | 7.1768   | 120.131   | 260.2628  | 518.8091   |
+| nccl-GDR (PXN<->PXN)     | 0.5541   | 45.601  | 9.3636   | 19.3071  | 108.11    | 280.0556  | 527.9732   |
+| torchrpc-TCP             | 5.6691   | 5.4707  | 14.0155  | 39.4443  | 580.333   | 1154.0793 | 2297.3776  |
+| torchrpc-IB              | 21.3884  | 4.4093  | 5.9105   | 22.3012  | 130.249   | 236.8084  | 477.2389   |
+| torchrpc-GDR (PXN<->PXN) | 20.5018  | 23.2081 | 15.6427  | 7.5357*  | 48.7812   | 77.2657   | 143.4112   |
+
+### Atari performance
+Results of dizoo/atari/example/atari_dqn_dist_rdma.py with the following resources:
+- memory: "32Gi"
+- cpu: 16
+- gpu: A100
+
+| test case (unit: s) | avg     |
+| ------------------- | ------- |
+| TCP-nng             | 127.64  |
+| torchrpc-CP         | 29.3906 |
+| torchrpc-IB         | 28.7763 |
+
 ## Problems you may encounter
 
 Message queue of Torchrpc uses [tensorpipe](https://github.com/pytorch/tensorpipe) as a communication backend, a high-performance modular tensor-p2p communication library. However, several tensorpipe defects have been found in the test, which may make it difficult for you to use it.
@@ -10,4 +47,8 @@ Tensorpipe is not container aware. Processes can find themselves on the same phy
 
 ### 2. RDMA and fork subprocess
 
-Tensorpipe does not consider the case of calling [fork(2)](https://man7.org/linux/man-pages/man2/fork.2.html) when using RDMA. If the corresponding initialization measures are not performed when using RDMA, using fork will cause serious problems, refer to [here](https://www.rdmamojo.com/2012/05/24/ibv_fork_init/). Therefore, if you start ditask in the IB/RoCE network environment, please specify the environment variables `IBV_FORK_SAFE=1` and `RDMAV_FORK_SAFE=1` , so that ibverbs will automatically initialize fork support.
\ No newline at end of file
+Tensorpipe does not consider the case of calling [fork(2)](https://man7.org/linux/man-pages/man2/fork.2.html) when using RDMA. If the corresponding initialization is not performed, forking a process that uses RDMA can cause serious problems; see [ibv_fork_init](https://www.rdmamojo.com/2012/05/24/ibv_fork_init/). Therefore, if you start ditask in an IB/RoCE network environment, set the environment variables `IBV_FORK_SAFE=1` and `RDMAV_FORK_SAFE=1` so that ibverbs automatically initializes fork support.
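+
+For example (the package and entry-point arguments below are placeholders for your own project; only `--mq-type` comes from this repo's CLI):
+
+```bash
+# Make ibverbs fork-safe before ditask forks any subprocess.
+export IBV_FORK_SAFE=1
+export RDMAV_FORK_SAFE=1
+ditask --package ./my_project --main my_project.main --mq-type torchrpc:cpu
+```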
+
+### 3. GPU direct RDMA
+
+If you use torchrpc in an environment that supports GPU direct RDMA and the tensor transmitted in an RPC is very small (less than 32 B), a segmentation fault may occur; see this [issue](https://github.com/pytorch/pytorch/issues/57136). We are tracking this bug and hope it can be resolved eventually.
diff --git a/ding/framework/message_queue/perfs/perf_shm.py b/ding/framework/message_queue/perfs/perf_shm.py
index ee9fbc1030..3bb3cfd0cf 100644
--- a/ding/framework/message_queue/perfs/perf_shm.py
+++ b/ding/framework/message_queue/perfs/perf_shm.py
@@ -3,8 +3,9 @@
 from ditk import logging
 from ding.framework.supervisor import RecvPayload, SendPayload, Supervisor, ChildType
 from ding.envs.env_manager.subprocess_env_manager import ShmBufferContainer, ShmBuffer
-from ding.utils.comm_perf_helper import tensor_size_beauty_print, byte_beauty_print, \
+from ding.utils.comm_perf_helper import tensor_size_beauty_print, \
     dtype_2_byte, TENSOR_SIZE_LIST, print_timer_result_csv
+from ding.utils import byte_beauty_print
 
 import torch
 import numpy as np
@@ -37,7 +38,7 @@ def cuda_shm_callback(payload: RecvPayload, buffers: Any):
     assert tensor.device == torch.device('cuda:1')
 
 
-class Recvier:
+class Receiver:
 
     def step(self, idx: int, __start_time):
         return {"idx": idx, "start_time": __start_time}
@@ -56,7 +57,7 @@ def __init__(self, gpu_tensors, buffers, ctx, is_cuda_buffer):
             _shm_callback = shm_callback
         else:
             _shm_callback = cuda_shm_callback
-        self.register(Recvier, shm_buffer=self.buffers, shm_callback=_shm_callback)
+        self.register(Receiver, shm_buffer=self.buffers, shm_callback=_shm_callback)
         super().start_link()
 
     def _send_recv_callback(self, payload: RecvPayload, remain_payloads: Optional[Dict[str, SendPayload]] = None):
diff --git a/ding/framework/message_queue/perfs/perf_torchrpc_nccl.py b/ding/framework/message_queue/perfs/perf_torchrpc_nccl.py
index cdf29b063e..4596320696 100644
--- a/ding/framework/message_queue/perfs/perf_torchrpc_nccl.py
+++ b/ding/framework/message_queue/perfs/perf_torchrpc_nccl.py
@@ -12,8 +12,9 @@
 from ding.utils.data.structure.lifo_deque import LifoDeque
 from ding.framework.message_queue.torch_rpc import DeviceMap, TORCHRPCMQ, RPCEvent
-from ding.utils.comm_perf_helper import tensor_size_beauty_print, byte_beauty_print, \
+from ding.utils.comm_perf_helper import tensor_size_beauty_print, \
     dtype_2_byte, DO_PERF, time_perf_avg, time_perf_once, print_timer_result_csv
+from ding.utils import byte_beauty_print
 
 LENGTH = 5
 REPEAT = 2
diff --git a/ding/framework/middleware/functional/collector.py b/ding/framework/middleware/functional/collector.py
index eeaf77e67a..e0140f0142 100644
--- a/ding/framework/middleware/functional/collector.py
+++ b/ding/framework/middleware/functional/collector.py
@@ -136,11 +136,15 @@ def _rollout(ctx: "OnlineRLContext"):
         # torchrpc currently uses "cuda:0" as the transmission device by default,
         # so all data on the cpu side is copied to "cuda:0" here. In fact this
        # copy is unnecessary, because torchrpc can support both cpu side and gpu
-        # side data to communicate using RDMA, but mixing the two transfer types
-        # will cause a bug, see issue:
-        # Because we have copied the large payload "obs" and "next_obs" from the
-        # collector's subprocess to "cuda:0" in advance, the copy operation here
-        # will not have too much overhead.
+        # side data when communicating over RDMA.
+        # However, we hit a bug in the unittests, see: https://github.com/pytorch/pytorch/issues/57136
+        # We adopt two strategies to avoid it:
+        # 1. Try not to mix cpu and gpu args in one rpc call.
+        #    Because the large payloads "obs" and "next_obs" have already been copied from the
+        #    collector's subprocess to "cuda:0", the copy operation here does not add much overhead.
+        # 2. Don't make the tensor size too small when using GPU direct RDMA.
+
         if use_cuda_shared_memory:
             transition = to_device(transition, "cuda:0")
         transitions.append(timestep.env_id, transition)
@@ -149,6 +153,5 @@ def _rollout(ctx: "OnlineRLContext"):
             env_episode_id[timestep.env_id] = current_id
             current_id += 1
             ctx.env_episode += 1
-            # TODO log
 
     return _rollout
diff --git a/ding/framework/middleware/functional/trainer.py b/ding/framework/middleware/functional/trainer.py
index 28f06472d3..ecc7994b62 100644
--- a/ding/framework/middleware/functional/trainer.py
+++ b/ding/framework/middleware/functional/trainer.py
@@ -71,7 +71,7 @@ def _train(ctx: Union["OnlineRLContext", "OfflineRLContext"]):
         if ctx.train_data is None:  # no enough data from data fetcher
             return
-        # data = ctx.train_data.to(policy._device)
+        ctx.train_data = ctx.train_data.to(policy._device)
         train_output = policy.forward(ctx.train_data)
         nonlocal last_log_iter
         if ctx.train_iter - last_log_iter >= log_freq:
diff --git a/ding/utils/__init__.py b/ding/utils/__init__.py
index 88f39e0d3d..9d06275eac 100644
--- a/ding/utils/__init__.py
+++ b/ding/utils/__init__.py
@@ -12,7 +12,7 @@
     DEFAULT_K8S_LEARNER_PORT, DEFAULT_K8S_AGGREGATOR_SLAVE_PORT, DEFAULT_K8S_COORDINATOR_PORT, pod_exec_command, \
     K8sLauncher
 from .lock_helper import LockContext, LockContextType, get_file_lock, get_rw_file_lock, synchronized
-from .log_helper import build_logger, pretty_print, LoggerFactory
+from .log_helper import build_logger, pretty_print, LoggerFactory, byte_beauty_print
 from .log_writer_helper import DistributedWriter
 from .orchestrator_launcher import OrchestratorLauncher
 from .profiler_helper import Profiler, register_profiler
@@ -37,5 +37,5 @@
 from .pytorch_ddp_dist_helper import get_rank, get_world_size, dist_mode, dist_init, dist_finalize, \
     allreduce, broadcast, DistContext, allreduce_async, synchronize
 
-from .comm_perf_helper import TENSOR_SIZE_LIST, DO_PERF, tensor_size_beauty_print, byte_beauty_print, \
-    dtype_2_byte, time_perf_avg, time_perf_once, print_timer_result_csv
+from .comm_perf_helper import TENSOR_SIZE_LIST, DO_PERF, tensor_size_beauty_print, dtype_2_byte, \
+    time_perf_avg, time_perf_once, print_timer_result_csv
diff --git a/ding/utils/comm_perf_helper.py b/ding/utils/comm_perf_helper.py
index 416b794c56..2dda44ba60 100644
--- a/ding/utils/comm_perf_helper.py
+++ b/ding/utils/comm_perf_helper.py
@@ -4,7 +4,7 @@
 from concurrent import futures
 from ditk import logging
 from typing import List, Optional, Tuple, Dict, Any
-from ding.utils import EasyTimer
+from ding.utils import EasyTimer, byte_beauty_print
 
 # Data size for some tests
 UNIT_1_B = 1
@@ -124,22 +124,3 @@ def dtype_2_byte(dtype: torch.dtype) -> int:
 def tensor_size_beauty_print(length: int, dtype: torch.dtype)
-> tuple: return byte_beauty_print(length * dtype_2_byte(dtype)) - - -def byte_beauty_print(nbytes: int) -> tuple: - """ - Overview: - Output the bytes in a human-readable format. - Arguments: - - nbytes (int): number of bytes. - - Returns: - tuple: tuple of formatted bytes and units. - """ - unit_dict = [("GB", 1024 * 1024 * 1024), ("MB", 1024 * 1024), ("KB", 1024), ("B", 1)] - - for item in unit_dict: - if nbytes // item[1] > 0: - return nbytes / item[1], item[0] - - return nbytes, "B" diff --git a/ding/utils/lock_helper.py b/ding/utils/lock_helper.py index 02c31c2191..3dc8f6b9e1 100644 --- a/ding/utils/lock_helper.py +++ b/ding/utils/lock_helper.py @@ -143,7 +143,7 @@ def synchronized(func): Overview: thread lock decorator. Arguments: - - func ([type]): A function that needs to be protected by a lock. + - func ([Callable]): A function that needs to be protected by a lock. """ func.__lock__ = threading.Lock() diff --git a/ding/utils/log_helper.py b/ding/utils/log_helper.py index 3c83e5242f..5b5887e18f 100644 --- a/ding/utils/log_helper.py +++ b/ding/utils/log_helper.py @@ -150,3 +150,22 @@ def pretty_print(result: dict, direct_print: bool = True) -> str: if direct_print: print(string) return string + + +def byte_beauty_print(nbytes: int) -> tuple: + """ + Overview: + Output the bytes in a human-readable format. + Arguments: + - nbytes (int): number of bytes. + + Returns: + tuple: tuple of formatted bytes and units. + """ + unit_dict = [("GB", 1024 ** 3), ("MB", 1024 ** 2), ("KB", 1024), ("B", 1)] + + for item in unit_dict: + if nbytes // item[1] > 0: + return nbytes / item[1], item[0] + + return nbytes, "B"
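+
+
+# Usage sketch (illustrative): byte_beauty_print returns the value converted to the
+# largest unit that fits, together with the unit name, e.g.
+#   byte_beauty_print(1536)           -> (1.5, 'KB')
+#   byte_beauty_print(40 * 1024 ** 2) -> (40.0, 'MB')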