[Feature] Support engine with NPU backend. #572

Merged · 20 commits · Oct 24, 2022
Changes from 11 commits
4 changes: 2 additions & 2 deletions mmengine/device/__init__.py
@@ -1,8 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .utils import (get_device, get_max_cuda_memory, is_cuda_available,
-                    is_mlu_available, is_mps_available)
+                    is_mlu_available, is_mps_available, is_npu_available)
 
 __all__ = [
     'get_max_cuda_memory', 'get_device', 'is_cuda_available',
-    'is_mlu_available', 'is_mps_available'
+    'is_mlu_available', 'is_mps_available', 'is_npu_available'
 ]
15 changes: 13 additions & 2 deletions mmengine/device/utils.py
@@ -32,6 +32,15 @@ def is_cuda_available() -> bool:
     return torch.cuda.is_available()
 
 
+def is_npu_available() -> bool:
+    """Returns True if Ascend PyTorch and npu devices exist."""
+    try:
+        import torch_npu  # noqa: F401
+    except Exception:
+        return False
+    return hasattr(torch, 'npu') and torch.npu.is_available()
+
+
 def is_mlu_available() -> bool:
     """Returns True if Cambricon PyTorch and mlu devices exist."""
     return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available()
@@ -49,9 +58,11 @@ def get_device() -> str:
     """Returns the currently existing device type.
 
     Returns:
-        str: cuda | mlu | mps | cpu.
+        str: cuda | npu | mlu | mps | cpu.
     """
-    if is_cuda_available():
+    if is_npu_available():
+        return 'npu'
+    elif is_cuda_available():
         return 'cuda'
     elif is_mlu_available():
         return 'mlu'
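Usage note (not part of the diff): a minimal sketch of how the new helper composes with `get_device`; it assumes an Ascend environment with `torch_npu` installed for the NPU branch, and degrades to CPU elsewhere.

```python
import torch
from mmengine.device import get_device, is_npu_available

device = get_device()  # 'npu' takes priority over 'cuda', 'mlu', 'mps', 'cpu'
x = torch.zeros(2, 3)
if is_npu_available():
    # The import guard in is_npu_available() returns False instead of
    # raising when torch_npu is absent, so this branch is safe to gate on.
    x = x.to('npu')
print(device, x.device)
```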
12 changes: 9 additions & 3 deletions mmengine/dist/dist.py
@@ -20,6 +20,7 @@
                        get_comm_device, cast_data_device)
 from mmengine.utils import digit_version
 from mmengine.utils.dl_utils import TORCH_VERSION
+from mmengine.device import is_npu_available
 
 
 def _get_reduce_op(name: str) -> torch_dist.ReduceOp:
@@ -411,7 +412,11 @@ def _broadcast_object_list(object_list: List[Any],
     group_backend = get_backend(group)
     is_nccl_backend = group_backend == torch_dist.Backend.NCCL
     current_device = torch.device('cpu')
-    if is_nccl_backend:
+    is_hccl_backend = group_backend == 'hccl'
+    if is_hccl_backend:
+        current_device = torch.npu.current_device()
+        object_sizes_tensor = object_sizes_tensor.to(current_device)
+    elif is_nccl_backend:
         # See note about using torch.cuda.current_device() here in
         # docstring. We cannot simply use my_rank since rank == device is
         # not necessarily true.
@@ -430,7 +435,7 @@ def _broadcast_object_list(object_list: List[Any],
             dtype=torch.uint8,
         )
 
-    if is_nccl_backend:
+    if is_nccl_backend or is_hccl_backend:
         object_tensor = object_tensor.to(current_device)
     torch_dist.broadcast(object_tensor, src=src, group=group)
     # Deserialize objects using their stored sizes.
@@ -504,7 +509,8 @@ def broadcast_object_list(data: List[Any],
     if group is None:
         group = get_default_group()
 
-    if digit_version(TORCH_VERSION) >= digit_version('1.8.0'):
+    if digit_version(TORCH_VERSION) >= digit_version(
+            '1.8.0') and not is_npu_available():
         torch_dist.broadcast_object_list(data, src, group)
     else:
         _broadcast_object_list(data, src, group)
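For context, a sketch of the caller-side behavior this enables (hypothetical values; assumes ranks were already initialized with the hccl backend, as in mmengine/dist/utils.py below). On NPU the version check is bypassed and the fallback `_broadcast_object_list` runs, which now moves the size and payload tensors onto the current npu device.

```python
from mmengine.dist import broadcast_object_list, get_rank

# Rank 0 owns the payload; other ranks pass a same-length placeholder.
data = [{'epoch': 3, 'seed': 42}] if get_rank() == 0 else [None]
broadcast_object_list(data, src=0)
# Afterwards data[0] == {'epoch': 3, 'seed': 42} on every rank.
```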
15 changes: 13 additions & 2 deletions mmengine/dist/utils.py
@@ -10,7 +10,7 @@
 from torch import Tensor
 from torch import distributed as torch_dist
 from torch.distributed import ProcessGroup
-from mmengine.device import is_mlu_available
+from mmengine.device import is_mlu_available, is_npu_available
 
 from collections.abc import Iterable, Mapping
 
@@ -80,6 +80,14 @@ def _init_dist_pytorch(backend, **kwargs) -> None:
         rank=rank,
         world_size=int(os.environ['WORLD_SIZE']),
         **kwargs)
+    elif is_npu_available():
+        import torch_npu  # noqa: F401
+        torch.npu.set_device(rank)
+        torch_dist.init_process_group(
+            backend='hccl',
+            rank=rank,
+            world_size=int(os.environ['WORLD_SIZE']),
+            **kwargs)
     else:
         num_gpus = torch.cuda.device_count()
         torch.cuda.set_device(rank % num_gpus)
@@ -437,7 +445,10 @@ def get_comm_device(group: Optional[ProcessGroup] = None) -> torch.device:
         torch.device: The device of backend.
     """
     backend = get_backend(group)
-    if backend == torch_dist.Backend.NCCL:
+    if backend == 'hccl':
+        import torch_npu  # noqa: F401
+        return torch.device('npu', torch.npu.current_device())
+    elif backend == torch_dist.Backend.NCCL:
         return torch.device('cuda', torch.cuda.current_device())
     elif backend == 'cncl':
         import torch_mlu  # noqa: F401
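A hypothetical launch sketch under these changes (not part of the diff; assumes `torchrun` sets RANK/WORLD_SIZE and that `torch_npu` is installed, so the NPU branch of `_init_dist_pytorch` picks hccl regardless of the backend argument):

```python
# train.py -- launched via: torchrun --nproc_per_node=8 train.py
from mmengine.dist import get_comm_device, init_dist

init_dist(launcher='pytorch')   # hccl is selected automatically on NPU
print(get_comm_device())        # e.g. device(type='npu', index=0)
```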
19 changes: 19 additions & 0 deletions mmengine/model/base_model/base_model.py
@@ -210,6 +210,25 @@ def cuda(
         self._set_device(torch.device(device))
         return super().cuda(device)
 
+    def npu(
+        self,
+        device: Optional[Union[int, str, torch.device]] = None,
+    ) -> nn.Module:
+        """Overrides this method to call :meth:`BaseDataPreprocessor.npu`
+        additionally.
+
+        Returns:
+            nn.Module: The model itself.
+
+        Note:
+            This generation of NPU (Ascend 910) does not support
+            the use of multiple cards in a single process, so the
+            index here needs to be consistent with the default device.
+        """
+        device = torch.npu.current_device()
+        self._set_device(device)
+        return super().npu()
+
     def cpu(self, *args, **kwargs) -> nn.Module:
         """Overrides this method to call :meth:`BaseDataPreprocessor.cpu`
         additionally.
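A small sketch of the intended call pattern (hypothetical toy model, not part of this diff; requires an Ascend environment):

```python
import torch.nn as nn
from mmengine.model import BaseModel

class ToyModel(BaseModel):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, inputs, data_samples=None, mode='tensor'):
        return self.linear(inputs)

# Moves the parameters and the data preprocessor together; the device
# index is forced to the current default npu device (see the Note above).
model = ToyModel().npu()
```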
9 changes: 7 additions & 2 deletions mmengine/optim/optimizer/amp_optimizer_wrapper.py
@@ -3,13 +3,18 @@
 
 import torch
 import torch.nn as nn
-from torch.cuda.amp import GradScaler
 
+from mmengine.device import is_npu_available, is_cuda_available
 from mmengine.registry import OPTIM_WRAPPERS
 from mmengine.utils import digit_version
 from mmengine.utils.dl_utils import TORCH_VERSION
 from .optimizer_wrapper import OptimWrapper
 
+if is_npu_available():
+    from torch.npu.amp import GradScaler
+else:
+    from torch.cuda.amp import GradScaler
 
 
 @OPTIM_WRAPPERS.register_module()
 class AmpOptimWrapper(OptimWrapper):
@@ -44,7 +49,7 @@ class AmpOptimWrapper(OptimWrapper):
     def __init__(self, loss_scale='dynamic', **kwargs):
         assert digit_version(TORCH_VERSION) >= digit_version('1.6.0'), (
             '`torch.cuda.amp` is only available when pytorch version >= 1.6')
-        assert torch.cuda.is_available(), (
+        assert is_cuda_available() or is_npu_available(), (
             '``AmpOptimizerWrapper`` is only available training on gpu')
         super().__init__(**kwargs)
         self._scale_update_param = None
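For reference, a minimal sketch of how this wrapper is constructed (hypothetical values; the `GradScaler` picked at import time is `torch.npu.amp.GradScaler` on Ascend and `torch.cuda.amp.GradScaler` otherwise):

```python
import torch.nn as nn
from torch.optim import SGD
from mmengine.optim import AmpOptimWrapper

model = nn.Linear(4, 2)
optim_wrapper = AmpOptimWrapper(
    loss_scale='dynamic',  # a GradScaler with default arguments
    optimizer=SGD(model.parameters(), lr=0.01))
```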
8 changes: 8 additions & 0 deletions mmengine/optim/optimizer/builder.py
@@ -7,6 +7,7 @@
 import torch.nn as nn
 
 from mmengine.config import Config, ConfigDict
+from mmengine.device import is_npu_available
 from mmengine.registry import OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS
 from .optimizer_wrapper import OptimWrapper
 
@@ -53,6 +54,13 @@ def build_optim_wrapper(model: nn.Module,
     constructor_type = optim_wrapper_cfg.pop('constructor',
                                              'DefaultOptimWrapperConstructor')
     paramwise_cfg = optim_wrapper_cfg.pop('paramwise_cfg', None)
+
+    # Since the current generation of NPU (Ascend 910) only supports
+    # mixed precision training, mixed precision is turned on by default
+    # on NPU so that training runs normally.
+    if is_npu_available():
+        optim_wrapper_cfg['type'] = 'AmpOptimWrapper'
+

     optim_wrapper_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build(
         dict(
             type=constructor_type,
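The user-visible effect, sketched with hypothetical config values: a plain `OptimWrapper` config is silently promoted to `AmpOptimWrapper` when an NPU is detected.

```python
import torch.nn as nn
from mmengine.optim import build_optim_wrapper

model = nn.Linear(4, 2)
cfg = dict(
    type='OptimWrapper',  # becomes 'AmpOptimWrapper' on NPU
    optimizer=dict(type='SGD', lr=0.01, momentum=0.9))
optim_wrapper = build_optim_wrapper(model, cfg)
```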
7 changes: 5 additions & 2 deletions mmengine/runner/amp.py
@@ -5,7 +5,7 @@
 
 import torch
 
-from mmengine.device import get_device
+from mmengine.device import get_device, is_npu_available, is_cuda_available
 from mmengine.logging import print_log
 from mmengine.utils import digit_version
 from mmengine.utils.dl_utils import TORCH_VERSION
@@ -86,7 +86,10 @@ def autocast(device_type: Optional[str] = None,
                 logger='current',
                 level=logging.WARNING)
 
-    if torch.cuda.is_available():
+    if is_npu_available():
+        with torch.npu.amp.autocast(enabled=enabled):
+            yield
+    elif is_cuda_available():
         with torch.cuda.amp.autocast(enabled=enabled):
             yield
     else:
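Usage stays the same for callers; a minimal sketch with made-up data (the context manager dispatches to `torch.npu.amp.autocast` on Ascend and `torch.cuda.amp.autocast` on CUDA):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.runner import autocast

model = nn.Linear(4, 2)
inputs = torch.randn(8, 4)
targets = torch.randint(0, 2, (8,))
with autocast():  # backend-appropriate mixed precision
    loss = F.cross_entropy(model(inputs), targets)
```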
6 changes: 4 additions & 2 deletions tests/test_device/test_device.py
@@ -1,11 +1,13 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from mmengine.device import (get_device, is_cuda_available, is_mlu_available,
-                             is_mps_available)
+                             is_mps_available, is_npu_available)
 
 
 def test_get_device():
     device = get_device()
-    if is_cuda_available():
+    if is_npu_available():
+        assert device == 'npu'
+    elif is_cuda_available():
         assert device == 'cuda'
     elif is_mlu_available():
         assert device == 'mlu'