Skip to content

Commit

Permalink
[Fix] Support BoolTensor and LongTensor on Ascend NPU.
Browse files Browse the repository at this point in the history
  • Loading branch information
Ginray committed Apr 10, 2023
1 parent 8bf1eca commit 98464dc
Showing 1 changed file with 17 additions and 10 deletions.
27 changes: 17 additions & 10 deletions mmengine/structures/instance_data.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,26 @@
# Copyright (c) OpenMMLab. All rights reserved.
import itertools
from collections.abc import Sized
from typing import List, Union
from typing import Any, List, Union

import numpy as np
import torch

from mmengine.device import get_device
from .base_data_element import BaseDataElement

IndexType = Union[str, slice, int, list, torch.LongTensor,
torch.cuda.LongTensor, torch.BoolTensor,
torch.cuda.BoolTensor, np.ndarray]
BoolTypeTensor: Union[Any]
LongTypeTensor: Union[Any]

if get_device() == 'npu':
BoolTypeTensor = Union[torch.BoolTensor, torch.npu.BoolTensor]
LongTypeTensor = Union[torch.LongTensor, torch.npu.LongTensor]
else:
BoolTypeTensor = Union[torch.BoolTensor, torch.cuda.BoolTensor]
LongTypeTensor = Union[torch.LongTensor, torch.cuda.LongTensor]

IndexType: Union[Any] = Union[str, slice, int, list, LongTypeTensor,
BoolTypeTensor, np.ndarray]


# Modified from
Expand Down Expand Up @@ -156,6 +166,7 @@ def __getitem__(self, item: IndexType) -> 'InstanceData':
Returns:
:obj:`InstanceData`: Corresponding values.
"""
assert isinstance(item, IndexType.__args__)
if isinstance(item, list):
item = np.array(item)
if isinstance(item, np.ndarray):
Expand All @@ -165,9 +176,6 @@ def __getitem__(self, item: IndexType) -> 'InstanceData':
# More details in https://github.com/numpy/numpy/issues/9464
item = item.astype(np.int64) if item.dtype == np.int32 else item
item = torch.from_numpy(item)
assert isinstance(
item, (str, slice, int, torch.LongTensor, torch.cuda.LongTensor,
torch.BoolTensor, torch.cuda.BoolTensor))

if isinstance(item, str):
return getattr(self, item)
Expand All @@ -183,7 +191,7 @@ def __getitem__(self, item: IndexType) -> 'InstanceData':
if isinstance(item, torch.Tensor):
assert item.dim() == 1, 'Only support to get the' \
' values along the first dimension.'
if isinstance(item, (torch.BoolTensor, torch.cuda.BoolTensor)):
if isinstance(item, BoolTypeTensor.__args__):
assert len(item) == len(self), 'The shape of the ' \
'input(BoolTensor) ' \
f'{len(item)} ' \
Expand All @@ -202,8 +210,7 @@ def __getitem__(self, item: IndexType) -> 'InstanceData':
v, (str, list, tuple)) or (hasattr(v, '__getitem__')
and hasattr(v, 'cat')):
# convert to indexes from BoolTensor
if isinstance(item,
(torch.BoolTensor, torch.cuda.BoolTensor)):
if isinstance(item, BoolTypeTensor.__args__):
indexes = torch.nonzero(item).view(
-1).cpu().numpy().tolist()
else:
Expand Down

0 comments on commit 98464dc

Please sign in to comment.