Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Add support for Ascend device #847

Merged
merged 2 commits into from
Jan 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions mmengine/device/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ def is_npu_available() -> bool:
"""Returns True if Ascend PyTorch and npu devices exist."""
try:
import torch_npu # noqa: F401

# Enable operator support for dynamic shape and
# binary operator support on the NPU.
torch.npu.set_compile_mode(jit_compile=False)
wangjiangben-hw marked this conversation as resolved.
Show resolved Hide resolved
except Exception:
return False
return hasattr(torch, 'npu') and torch.npu.is_available()
Expand Down
12 changes: 12 additions & 0 deletions mmengine/model/base_model/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,18 @@ def to(self, *args, **kwargs) -> nn.Module:
Returns:
nn.Module: The model itself.
"""

# Since Torch has not officially merged
# the npu-related fields, using the _parse_to function
# directly will cause the NPU to not be found.
# Here, the input parameters are processed to avoid errors.
if args and isinstance(args[0], str) and 'npu' in args[0]:
args = tuple(
[list(args)[0].replace('npu', torch.npu.native_device)])
if kwargs and 'npu' in str(kwargs.get('device', '')):
kwargs['device'] = kwargs['device'].replace(
'npu', torch.npu.native_device)

device = torch._C._nn._parse_to(*args, **kwargs)[0]
if device is not None:
self._set_device(torch.device(device))
Expand Down
21 changes: 21 additions & 0 deletions mmengine/model/base_model/data_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,18 @@ def to(self, *args, **kwargs) -> nn.Module:
Returns:
nn.Module: The model itself.
"""

# Since Torch has not officially merged
# the npu-related fields, using the _parse_to function
# directly will cause the NPU to not be found.
# Here, the input parameters are processed to avoid errors.
if args and isinstance(args[0], str) and 'npu' in args[0]:
args = tuple(
[list(args)[0].replace('npu', torch.npu.native_device)])
if kwargs and 'npu' in str(kwargs.get('device', '')):
kwargs['device'] = kwargs['device'].replace(
'npu', torch.npu.native_device)

device = torch._C._nn._parse_to(*args, **kwargs)[0]
if device is not None:
self._device = torch.device(device)
Expand All @@ -104,6 +116,15 @@ def cuda(self, *args, **kwargs) -> nn.Module:
self._device = torch.device(torch.cuda.current_device())
return super().cuda()

def npu(self, *args, **kwargs) -> nn.Module:
"""Overrides this method to set the :attr:`device`

Returns:
nn.Module: The model itself.
"""
self._device = torch.device(torch.npu.current_device())
return super().npu()

def cpu(self, *args, **kwargs) -> nn.Module:
"""Overrides this method to set the :attr:`device`

Expand Down
11 changes: 11 additions & 0 deletions mmengine/structures/base_data_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,17 @@ def cuda(self) -> 'BaseDataElement':
new_data.set_data(data)
return new_data

# Tensor-like methods
def npu(self) -> 'BaseDataElement':
"""Convert all tensors to NPU in data."""
new_data = self.new()
for k, v in self.items():
if isinstance(v, (torch.Tensor, BaseDataElement)):
v = v.npu()
data = {k: v}
new_data.set_data(data)
return new_data

# Tensor-like methods
def detach(self) -> 'BaseDataElement':
"""Detach all tensors in data."""
Expand Down