Skip to content

Commit

Permalink
[Fix] Support BoolTensor and LongTensor on Ascend NPU.
Browse files Browse the repository at this point in the history
  • Loading branch information
Ginray committed Mar 24, 2023
1 parent cbb6714 commit fbe04b7
Showing 1 changed file with 14 additions and 9 deletions.
23 changes: 14 additions & 9 deletions mmengine/structures/instance_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,18 @@
import numpy as np
import torch

from mmengine.device import is_npu_available
from .base_data_element import BaseDataElement

IndexType = Union[str, slice, int, list, torch.LongTensor,
torch.cuda.LongTensor, torch.BoolTensor,
torch.cuda.BoolTensor, np.ndarray]
BoolTypeTensor = Union[torch.BoolTensor, torch.cuda.BoolTensor,
torch.npu.BoolTensor] if is_npu_available() else \
Union[torch.BoolTensor, torch.cuda.BoolTensor]
LongTypeTensor = Union[torch.LongTensor, torch.cuda.LongTensor,
torch.npu.LongTensor] if is_npu_available() else \
Union[torch.LongTensor, torch.cuda.LongTensor]

IndexType = Union[str, slice, int, list, LongTypeTensor, BoolTypeTensor,
np.ndarray]


# Modified from
Expand Down Expand Up @@ -165,9 +172,8 @@ def __getitem__(self, item: IndexType) -> 'InstanceData':
# More details in https://github.com/numpy/numpy/issues/9464
item = item.astype(np.int64) if item.dtype == np.int32 else item
item = torch.from_numpy(item)
assert isinstance(
item, (str, slice, int, torch.LongTensor, torch.cuda.LongTensor,
torch.BoolTensor, torch.cuda.BoolTensor))
assert isinstance(item, (str, slice, int, LongTypeTensor.__args__,
BoolTypeTensor.__args__))

if isinstance(item, str):
return getattr(self, item)
Expand All @@ -183,7 +189,7 @@ def __getitem__(self, item: IndexType) -> 'InstanceData':
if isinstance(item, torch.Tensor):
assert item.dim() == 1, 'Only support to get the' \
' values along the first dimension.'
if isinstance(item, (torch.BoolTensor, torch.cuda.BoolTensor)):
if isinstance(item, BoolTypeTensor.__args__):
assert len(item) == len(self), 'The shape of the ' \
'input(BoolTensor) ' \
f'{len(item)} ' \
Expand All @@ -202,8 +208,7 @@ def __getitem__(self, item: IndexType) -> 'InstanceData':
v, (str, list, tuple)) or (hasattr(v, '__getitem__')
and hasattr(v, 'cat')):
# convert to indexes from BoolTensor
if isinstance(item,
(torch.BoolTensor, torch.cuda.BoolTensor)):
if isinstance(item, BoolTypeTensor.__args__):
indexes = torch.nonzero(item).view(
-1).cpu().numpy().tolist()
else:
Expand Down

0 comments on commit fbe04b7

Please sign in to comment.