Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions mindnlp/core/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,6 +812,9 @@ def record_stream(self, stream):
Tensor.record_stream = record_stream
StubTensor.record_stream = record_stream

Tensor.scatter = ops.scatter
StubTensor.scatter = ops.scatter

def _rebuild_from_type_v2(func, new_type, args, state):
ret = func(*args)
return ret
2 changes: 1 addition & 1 deletion mindnlp/core/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

SOC = MSContext.get_instance().get_ascend_soc_version()
DEVICE_TARGET = mindspore.get_context('device_target')
SUPPORT_BF16 = DEVICE_TARGET == 'Ascend' and SOC not in ['ascend910', 'ascend310b1', 'ascend310b4']
SUPPORT_BF16 = DEVICE_TARGET == 'Ascend' and SOC not in ['ascend910', 'ascend310b']
ON_A1 = not SUPPORT_BF16
ON_ORANGE_PI = '310b' in SOC
USE_PYBOOST = DEVICE_TARGET == 'Ascend'
Expand Down
16 changes: 15 additions & 1 deletion mindnlp/core/npu/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@

import mindspore
from mindspore import get_rng_state, set_rng_state, manual_seed
from mindspore.hal import *
from mindspore.runtime import memory_reserved as ms_memory_reserved, \
memory_allocated as ms_memory_allocated, StreamCtx as StreamContext, Stream, empty_cache, \
reset_peak_memory_stats, reset_max_memory_allocated, max_memory_allocated, synchronize, \
current_stream
from mindspore.device_context.ascend import device_count

from mindnlp import core
from ..configs import SUPPORT_BF16

FloatTensor = core.FloatTensor
HalfTensor = core.FloatTensor
Expand All @@ -28,6 +33,15 @@ def set_device(device):
def _lazy_call(callable, **kwargs):
callable()

def is_bf16_supported():
return SUPPORT_BF16

def memory_allocated(device=None):
return ms_memory_allocated()

def memory_reserved(device=None):
return ms_memory_reserved()

class device:
r"""Context-manager that changes the selected device.

Expand Down
16 changes: 10 additions & 6 deletions mindnlp/core/ops/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,13 @@ def gather(input, dim, index):
_complex = _get_cache_prim(ops.Complex)()
return _complex(real_part, imag_part)

if use_pyboost() and has_gather:
if use_pyboost() and has_gather and not ON_ORANGE_PI:
return mindspore.mint.gather(input, dim, index)

index = ops.where(index < input.shape[dim], index, index - input.shape[dim])
return ops.gather_elements(input, dim, index)
index = core.where(index < input.shape[dim], index, index - input.shape[dim])
if not ON_ORANGE_PI:
return ops.gather_elements(input, dim, index)
return tf_gather(input, index, dim, batch_dims=dim)

def gather_nd(input, indices):
return ops.gather_nd(input, indices)
Expand Down Expand Up @@ -195,10 +197,12 @@ def select(input, dim, index):
# scatter
has_scatter = hasattr(mindspore.mint, 'scatter')
def scatter(input, dim, index, src):
if use_pyboost() and has_scatter:
if use_pyboost() and has_scatter and not ON_ORANGE_PI:
return mindspore.mint.scatter(input, dim, index, src)
if not isinstance(src, mindspore.Tensor):
src = ops.full(index.shape, src, dtype=input.dtype)
if input.dtype == mindspore.bool_:
return ops.tensor_scatter_elements(input.int(), index, src.int(), dim).bool()
return ops.tensor_scatter_elements(input, index, src, dim)

def tf_scatter_nd_update(input, indices, updates):
Expand Down Expand Up @@ -352,7 +356,7 @@ def _take_along_dim_helper(self, indices, dim):
def take_along_dim(input, indices, dim=None, *, out=None):
if dim:
self_broadcasted, indices_broadcasted, dim = _take_along_dim_helper(input, indices, dim)
return self_broadcasted.gather(dim, indices_broadcasted)
return gather(self_broadcasted, dim, indices_broadcasted)
return input.view(-1).gather(0, indices.view(-1))

# tensor_split
Expand Down Expand Up @@ -427,7 +431,7 @@ def where(condition, *args, out=None):
input = mindspore.tensor(input, dtype=mindspore.float32)
other = finfo(input.dtype).min

if use_pyboost():
if use_pyboost() and not ON_ORANGE_PI:
output = mindspore.mint.where(condition, input, other)
else:
output = condition * input + (~condition) * other
Expand Down
6 changes: 3 additions & 3 deletions mindnlp/core/ops/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import mindspore
from mindspore import ops
from ..configs import use_pyboost
from ..configs import use_pyboost, ON_ORANGE_PI

from ._inner import call_ms_func

Expand Down Expand Up @@ -64,7 +64,7 @@ def greater(input, other, *, out=None):
# isclose
has_isclose = hasattr(mindspore.mint, 'isclose')
def isclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False):
if use_pyboost() and has_isclose:
if use_pyboost() and has_isclose and not ON_ORANGE_PI:
return mindspore.mint.isclose(input, other, rtol, atol, equal_nan)
return mindspore.tensor(np.isclose(input.numpy(), other.numpy(), rtol, atol, equal_nan))

Expand Down Expand Up @@ -174,7 +174,7 @@ def not_equal(input, other):
# sort
has_sort = hasattr(mindspore.mint, 'sort')
def sort(input, *, dim=-1, descending=False, stable=False):
if use_pyboost() and has_sort:
if use_pyboost() and has_sort and not ON_ORANGE_PI:
out = mindspore.mint.sort(input, dim=dim, descending=descending, stable=stable)
else:
out = ops.sort(input, dim, descending)
Expand Down
4 changes: 2 additions & 2 deletions mindnlp/core/ops/pointwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np
import mindspore
from mindspore import ops
from ..configs import use_pyboost, ON_A1
from ..configs import use_pyboost, ON_A1, ON_ORANGE_PI
from ._inner import call_ms_func

from mindnlp import core
Expand Down Expand Up @@ -582,7 +582,7 @@ def igammac(input, other):


def mul(input, other, *, out=None):
if use_pyboost() and has_mul:
if use_pyboost() and has_mul and not ON_ORANGE_PI:
out = mindspore.mint.mul(input, other)
else:
if input.dtype == mindspore.bool_:
Expand Down
2 changes: 2 additions & 0 deletions mindnlp/utils/torch_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ def exec_module(self, module):
class ProxyModule(type(module)):
def __getattr__(_, name):
# 动态导入实际模块中的属性
if DEVICE_TARGET == 'Ascend':
name = name.replace('cuda', 'npu')
try:
target_module = importlib.import_module(self.target_name)
except ImportError as e:
Expand Down
1 change: 0 additions & 1 deletion tests/run_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import os
import sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import pytest
import mindspore
Expand Down
Loading