Skip to content

Commit

Permalink
[Enhancement] Speed up mmcv import (#1249)
Browse files Browse the repository at this point in the history
* [Enhancement] Speed import mmcv

* fix missing parse_version

* fix circle dependency

* rename
  • Loading branch information
zhouzaida committed Aug 10, 2021
1 parent 94a677d commit 9fa5de8
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 22 deletions.
25 changes: 13 additions & 12 deletions mmcv/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@
from .logging import get_logger, print_log
from .parrots_jit import jit, skip_no_elena
from .parrots_wrapper import (
CUDA_HOME, TORCH_VERSION, BuildExtension, CppExtension, CUDAExtension,
DataLoader, PoolDataLoader, SyncBatchNorm, _AdaptiveAvgPoolNd,
_AdaptiveMaxPoolNd, _AvgPoolNd, _BatchNorm, _ConvNd,
_ConvTransposeMixin, _InstanceNorm, _MaxPoolNd, get_build_config)
TORCH_VERSION, BuildExtension, CppExtension, CUDAExtension, DataLoader,
PoolDataLoader, SyncBatchNorm, _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd,
_AvgPoolNd, _BatchNorm, _ConvNd, _ConvTransposeMixin, _InstanceNorm,
_MaxPoolNd, get_build_config, is_rocm_pytorch, _get_cuda_home)
from .registry import Registry, build_from_cfg
from .trace import is_jit_tracing
__all__ = [
Expand All @@ -54,15 +54,16 @@
'is_filepath', 'fopen', 'check_file_exist', 'mkdir_or_exist',
'symlink', 'scandir', 'ProgressBar', 'track_progress',
'track_iter_progress', 'track_parallel_progress', 'Registry',
'build_from_cfg', 'Timer', 'TimerError', 'check_time', 'CUDA_HOME',
'SyncBatchNorm', '_AdaptiveAvgPoolNd', '_AdaptiveMaxPoolNd',
'_AvgPoolNd', '_BatchNorm', '_ConvNd', '_ConvTransposeMixin',
'_InstanceNorm', '_MaxPoolNd', 'get_build_config', 'BuildExtension',
'CppExtension', 'CUDAExtension', 'DataLoader', 'PoolDataLoader',
'TORCH_VERSION', 'deprecated_api_warning', 'digit_version',
'get_git_hash', 'import_modules_from_strings', 'jit', 'skip_no_elena',
'build_from_cfg', 'Timer', 'TimerError', 'check_time', 'SyncBatchNorm',
'_AdaptiveAvgPoolNd', '_AdaptiveMaxPoolNd', '_AvgPoolNd', '_BatchNorm',
'_ConvNd', '_ConvTransposeMixin', '_InstanceNorm', '_MaxPoolNd',
'get_build_config', 'BuildExtension', 'CppExtension', 'CUDAExtension',
'DataLoader', 'PoolDataLoader', 'TORCH_VERSION',
'deprecated_api_warning', 'digit_version', 'get_git_hash',
'import_modules_from_strings', 'jit', 'skip_no_elena',
'assert_dict_contains_subset', 'assert_attrs_equal',
'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer',
'assert_params_all_zeros', 'check_python_script',
'is_method_overridden', 'is_jit_tracing'
'is_method_overridden', 'is_jit_tracing', 'is_rocm_pytorch',
'_get_cuda_home'
]
3 changes: 2 additions & 1 deletion mmcv/utils/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ def collect_env():
for name, device_ids in devices.items():
env_info['GPU ' + ','.join(device_ids)] = name

from mmcv.utils.parrots_wrapper import CUDA_HOME
from mmcv.utils.parrots_wrapper import _get_cuda_home
CUDA_HOME = _get_cuda_home()
env_info['CUDA_HOME'] = CUDA_HOME

if CUDA_HOME is not None and osp.isdir(CUDA_HOME):
Expand Down
20 changes: 11 additions & 9 deletions mmcv/utils/parrots_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,26 @@

import torch

from mmcv.utils import digit_version

TORCH_VERSION = torch.__version__

is_rocm_pytorch = False
if (TORCH_VERSION != 'parrots'
and digit_version(TORCH_VERSION) >= digit_version('1.5')):
from torch.utils.cpp_extension import ROCM_HOME
is_rocm_pytorch = True if ((torch.version.hip is not None) and

def is_rocm_pytorch() -> bool:
is_rocm = False
if TORCH_VERSION != 'parrots':
try:
from torch.utils.cpp_extension import ROCM_HOME
is_rocm = True if ((torch.version.hip is not None) and
(ROCM_HOME is not None)) else False
except ImportError:
pass
return is_rocm


def _get_cuda_home():
if TORCH_VERSION == 'parrots':
from parrots.utils.build_extension import CUDA_HOME
else:
if is_rocm_pytorch:
if is_rocm_pytorch():
from torch.utils.cpp_extension import ROCM_HOME
CUDA_HOME = ROCM_HOME
else:
Expand Down Expand Up @@ -86,7 +89,6 @@ def _get_norm():
return _BatchNorm, _InstanceNorm, SyncBatchNorm_


CUDA_HOME = _get_cuda_home()
_ConvNd, _ConvTransposeMixin = _get_conv()
DataLoader, PoolDataLoader = _get_dataloader()
BuildExtension, CppExtension, CUDAExtension = _get_extension()
Expand Down

0 comments on commit 9fa5de8

Please sign in to comment.