Skip to content

Commit

Permalink
add npu device support
Browse files Browse the repository at this point in the history
  • Loading branch information
wangjiangben-hw committed Dec 26, 2022
1 parent eb803f8 commit aed46ab
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 0 deletions.
1 change: 1 addition & 0 deletions mmengine/device/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def is_npu_available() -> bool:
"""Returns True if Ascend PyTorch and npu devices exist."""
try:
import torch_npu # noqa: F401
torch.npu.set_compile_mode(jit_compile=False)
except Exception:
return False
return hasattr(torch, 'npu') and torch.npu.is_available()
Expand Down
12 changes: 12 additions & 0 deletions mmengine/model/base_model/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,18 @@ def to(self, *args, **kwargs) -> nn.Module:
Returns:
nn.Module: The model itself.
"""

# Since Torch has not officially merged
# the npu-related fields, using the _parse_to function
# directly will cause the NPU to not be found.
# Here, the input parameters are processed to avoid errors.
if args and isinstance(args[0], str) and 'npu' in args[0]:
args = tuple([list(args)])[0].replace( # type: ignore
'npu', torch.npu.native_device)
if kwargs and 'npu' in str(kwargs.get('device', '')):
kwargs['device'] = kwargs['device'].replace(
'npu', torch.npu.native_device)

device = torch._C._nn._parse_to(*args, **kwargs)[0]
if device is not None:
self._set_device(torch.device(device))
Expand Down
21 changes: 21 additions & 0 deletions mmengine/model/base_model/data_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,18 @@ def to(self, *args, **kwargs) -> nn.Module:
Returns:
nn.Module: The model itself.
"""

# Since Torch has not officially merged
# the npu-related fields, using the _parse_to function
# directly will cause the NPU to not be found.
# Here, the input parameters are processed to avoid errors.
if args and isinstance(args[0], str) and 'npu' in args[0]:
args = tuple([list(args)])[0].replace( # type: ignore
'npu', torch.npu.native_device)
if kwargs and 'npu' in str(kwargs.get('device', '')):
kwargs['device'] = kwargs['device'].replace(
'npu', torch.npu.native_device)

device = torch._C._nn._parse_to(*args, **kwargs)[0]
if device is not None:
self._device = torch.device(device)
Expand All @@ -104,6 +116,15 @@ def cuda(self, *args, **kwargs) -> nn.Module:
self._device = torch.device(torch.cuda.current_device())
return super().cuda()

def npu(self, *args, **kwargs) -> nn.Module:
"""Overrides this method to set the :attr:`device`
Returns:
nn.Module: The model itself.
"""
self._device = torch.device(torch.npu.current_device())
return super().npu()

def cpu(self, *args, **kwargs) -> nn.Module:
"""Overrides this method to set the :attr:`device`
Expand Down
11 changes: 11 additions & 0 deletions mmengine/structures/base_data_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,17 @@ def cuda(self) -> 'BaseDataElement':
new_data.set_data(data)
return new_data

# Tensor-like methods
def npu(self) -> 'BaseDataElement':
"""Convert all tensors to NPU in data."""
new_data = self.new()
for k, v in self.items():
if isinstance(v, (torch.Tensor, BaseDataElement)):
v = v.npu()
data = {k: v}
new_data.set_data(data)
return new_data

# Tensor-like methods
def detach(self) -> 'BaseDataElement':
"""Detach all tensors in data."""
Expand Down

0 comments on commit aed46ab

Please sign in to comment.