4 changes: 2 additions & 2 deletions mindnlp/__init__.py
@@ -39,8 +39,8 @@
# for different ascend devices
if platform.system().lower() == 'linux':
    SOC = MSContext.get_instance().get_ascend_soc_version()
    if ('910b' not in SOC and '310' not in SOC) or version.parse(mindspore.__version__) < version.parse('2.4.0'):
        os.environ["MS_ALLOC_CONF"] = 'enable_vmm:True,vmm_align_size:2MB'
    # enable vmm since only vmm can release device memory when a tensor is deleted.
    os.environ["MS_ALLOC_CONF"] = 'enable_vmm:True,vmm_align_size:2MB'

    if SOC in ('ascend910', 'ascend310b'):
        # context.set_context(ascend_config={"precision_mode": "allow_mix_precision"})
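For context, MindSpore's allocator reads MS_ALLOC_CONF when the runtime initializes, so the variable has to be in the environment before any device memory is allocated. A minimal standalone sketch using the same value as the diff above; the setdefault guard (respecting a user-provided setting) is an assumption for illustration, not part of this change:

import os

# Hedged sketch: mirror the PR's allocator setting, but only if the user has not
# already configured MS_ALLOC_CONF themselves (the guard is illustrative).
os.environ.setdefault("MS_ALLOC_CONF", "enable_vmm:True,vmm_align_size:2MB")

import mindspore  # import after the variable is set so the allocator picks it up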
6 changes: 5 additions & 1 deletion mindnlp/core/_dtype.py
Expand Up @@ -12,6 +12,8 @@
from ml_dtypes import bfloat16 as np_bfloat16

bool_alias = bool
float_alias = float
int_alias = int

if ON_A1:
    warnings.warn('MindSpore on GPU/910A do not support bfloat16, use float16 instead.')
@@ -116,5 +118,7 @@ def __gt__(self, other):
dtype2np[bfloat16] = np_bfloat16

py2dtype = {
    bool_alias: bool
    bool_alias: bool,
    float_alias: float,
    int_alias: int64
}
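As a side note, a builtin-to-dtype table like py2dtype is typically consulted when a plain Python scalar needs a default dtype. A minimal sketch of that lookup; infer_default_dtype is an illustrative helper and not something this diff adds:

# Illustrative only: resolve a default dtype for a Python scalar via a
# py2dtype-style mapping ({bool: bool, float: float, int: int64}).
def infer_default_dtype(value, table):
    return table[type(value)]

# e.g. infer_default_dtype(3, py2dtype) would resolve to int64 under this mapping.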
47 changes: 44 additions & 3 deletions mindnlp/core/_prims/ascend.py
@@ -1,12 +1,12 @@
import numbers
import mindspore
from mindspore import ops
from mindspore.ops._primitive_cache import _get_cache_prim
from mindspore.ops.auto_generate import gen_ops_prim
from mindspore.ops.auto_generate import pyboost_inner_prim
from mindspore._c_expression import _empty_instance
from mindspore.ops.operations.math_ops import NPUGetFloatStatusV2, NPUClearFloatStatusV2
from mindspore.ops.operations.nn_ops import AllFinite

from mindspore.ops.auto_generate.gen_ops_prim import MaxPoolWithIndices, MaxPoolWithMask
from mindnlp import core
from mindnlp.core._C import default_generator

@@ -105,7 +105,12 @@ def tile(*args):
__all__.append('tile')

def pad_v3(input_x, padding, mode='constant', value=None):
    pad_op = ops.PadV3(mode=mode, paddings_contiguous=True).set_device('CPU')
    pad_op = ops.PadV3(mode=mode, paddings_contiguous=True).set_device('Ascend')
    if input_x.dtype == core.bool:
        input_x = input_x.to(core.int32)
        out = pad_op(input_x, padding, value)
        return cast(out, core.bool)

    if isinstance(value, (float, int)):
        value = core.tensor(value, dtype=input_x.dtype)
    return pad_op(input_x, padding, value)
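The new branch works around PadV3 not accepting bool inputs: the mask is cast to int32, padded, and cast back to bool. A rough usage sketch; the tensor values and the (left, right) padding layout are assumptions for illustration, not taken from this diff:

# Hypothetical call: pad a bool mask with False on both sides of the last axis.
mask = core.tensor([[True, False, True]], dtype=core.bool)
padded = pad_v3(mask, (1, 1), mode='constant', value=0)
# internally: bool -> int32, PadV3 on Ascend, then cast(out, core.bool) on the way out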
@@ -248,3 +253,39 @@ def triu(input, diagonal):
    return pyboost_inner_prim.triu_impl(input, diagonal)

__all__.append('triu')

masked_scatter_op = ops.MaskedScatter().set_device('Ascend')
def masked_scatter(input, mask, source):
    return masked_scatter_op(input, mask, source)

__all__.append('masked_scatter')

def roll(*args):
    return pyboost_inner_prim.roll_impl(*args)

__all__.append('roll')

lgamma_op = ops.Lgamma().set_device('Ascend')
def lgamma(input):
    return lgamma_op(input)

__all__.append('lgamma')

def max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode, return_indices):
    strides = stride if (stride is not None) else kernel_size
    if return_indices:
        max_pool_func_ = _get_cache_prim(MaxPoolWithIndices)(kernel_size, strides, padding, dilation, ceil_mode)
        out, indices = max_pool_func_(input)
    else:
        max_pool_func_ = _get_cache_prim(MaxPoolWithMask)(kernel_size, strides, padding, dilation, ceil_mode)
        out, indices = max_pool_func_(input)
    if return_indices:
        return out, indices
    return out

__all__.append('max_pool2d')
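max_pool2d dispatches to MaxPoolWithIndices when indices are requested and to MaxPoolWithMask otherwise; both primitives return (out, indices), and the wrapper discards the indices unless return_indices is set. A short usage sketch, assuming core.randn follows a torch-like shape signature (that assumption is not taken from this diff):

# Hypothetical usage on an NCHW tensor (shape chosen for illustration).
x = core.randn(1, 3, 8, 8)
out, idx = max_pool2d(x, kernel_size=2, stride=2, padding=0, dilation=1,
                      ceil_mode=False, return_indices=True)
pooled = max_pool2d(x, 2, 2, 0, 1, False, False)  # indices computed internally but discarded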

def unique_consecutive(*args):
    return pyboost_inner_prim.unique_consecutive_impl(*args)

__all__.append('unique_consecutive')
41 changes: 40 additions & 1 deletion mindnlp/core/_prims/meta.py
@@ -290,4 +290,43 @@ def linalg_vector_norm(input, p, dim, keepdim, dtype):
    out = Tensor_(shape=tuple(new_shape), dtype=dtype)
    return core.Tensor(out)

__all__.append('linalg_vector_norm')
__all__.append('linalg_vector_norm')

def erfinv(input):
    return input
__all__.append('erfinv')


def stop_gradient(input):
    out = Tensor_(shape=input.shape, dtype=input.dtype)
    return core.Tensor(out)

__all__.append('stop_gradient')

def log(input):
    return input
__all__.append('log')

def mul(input, other):
    out = Tensor_(shape=input.shape, dtype=input.dtype)
    return core.Tensor(out)
__all__.append('mul')

def randn(size, seed, offset, dtype):
    out = Tensor_(shape=size, dtype=dtype)
    return core.Tensor(out)

__all__.append('randn')

def zeros_like_ext(input, *args, **kwargs):
    out = Tensor_(shape=input.shape, dtype=input.dtype)
    return core.Tensor(out)
__all__.append('zeros_like_ext')

def inplace_add_ext(input, other, alpha):
    return input
__all__.append('inplace_add_ext')

def clamp_scalar(input, *args):
    return input
__all__.append('clamp_scalar')
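All of these meta-backend additions follow one pattern: no real computation, just a Tensor_ carrying the output shape and dtype (or the input passed straight through for shape- and dtype-preserving ops), so shape inference can run without a device. A minimal sketch of that template for a hypothetical op, not one added by this diff:

# Illustrative template: a meta prim only has to describe the output, never compute it.
def some_elementwise_op(input, other):
    out = Tensor_(shape=input.shape, dtype=input.dtype)  # shape/dtype only, no data
    return core.Tensor(out)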
112 changes: 109 additions & 3 deletions mindnlp/core/_prims/numpy.py
@@ -1,6 +1,8 @@
import numbers
import numpy as np
import scipy
from mindspore import ops
from mindspore.ops._primitive_cache import _get_cache_prim
from mindnlp import core

__all__ = []
@@ -28,6 +30,8 @@ def arange(start, end, step, dtype):
def div(input, other):
    if not isinstance(input, numbers.Number):
        input = input.numpy()
        if input.dtype == np.int64:
            input = input.astype(np.int32)
    elif not isinstance(other, numbers.Number):
        other = other.numpy()
    out = np.divide(input, other)
@@ -98,7 +102,15 @@ def cast(input, dtype):
__all__.append('cast')

def getitem(input, slice):
    out = input.asnumpy()[slice]
    if isinstance(slice, tuple):
        new_slice = ()
        for s in slice:
            if isinstance(s, core.Tensor):
                s = s.numpy()
            new_slice += (s,)
    else:
        new_slice = slice
    out = input.asnumpy()[new_slice]
    if not isinstance(out, np.ndarray):
        out = np.array(out)
    return core.Tensor.from_numpy(out)
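The new tuple branch converts any core.Tensor entries of the index into numpy arrays so numpy's advanced indexing accepts them. A small sketch of the kind of call this enables; the values are illustrative, not from this diff:

# Hypothetical example: a Tensor row index combined with a full slice.
x = core.Tensor.from_numpy(np.arange(12).reshape(3, 4))
rows = core.tensor([0, 2])
sub = getitem(x, (rows, slice(None)))  # rows is converted to numpy before indexing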
@@ -233,6 +245,8 @@ def concat(tensors, dim):

def abs(input):
    out = np.abs(input.numpy())
    if not isinstance(out, np.ndarray):
        out = np.array(out)
    return core.Tensor.from_numpy(out)

__all__.append('abs')
@@ -277,6 +291,8 @@ def identity(input):
# def non_zero()
def isclose(input, other, rtol, atol, equal_nan):
    out = np.isclose(input.numpy(), other.numpy(), rtol, atol, equal_nan)
    if not isinstance(out, np.ndarray):
        out = np.array(out)
    return core.Tensor.from_numpy(out)

__all__.append('isclose')
@@ -308,8 +324,11 @@ def index_select(input, dim, index):
__all__.append('index_select')

def rand_ext(size, seed, offset, dtype):
    out = np.random.randn(*size).astype(core.dtype2np[dtype])
    return core.Tensor.from_numpy(out[0])
    out = np.random.randn(*size)
    if not isinstance(out, np.ndarray):
        out = np.array(out)
    out = out.astype(core.dtype2np[dtype])
    return core.Tensor.from_numpy(out)

__all__.append('rand_ext')

@@ -438,6 +457,9 @@ def less(input, other):
        other = other.numpy()

    out = input < other
    if not isinstance(out, np.ndarray):
        out = np.array(out)

    return core.Tensor.from_numpy(out)

__all__.append('less')
@@ -529,3 +551,87 @@ def randn(size, seed, offset, dtype):
    return core.Tensor.from_numpy(out)

__all__.append('randn')

def erfinv(input):
    out = scipy.special.erfinv(input)
    return core.Tensor.from_numpy(out)

__all__.append('erfinv')

def inplace_add_ext(input, other, alpha):
    if not isinstance(other, numbers.Number):
        other = other.numpy()
    out = input.numpy() + other * alpha
    input.assign_value(core.Tensor.from_numpy(out))
    return input

__all__.append('inplace_add_ext')

def pow_tensor_scalar(input, other):
    out = np.power(input.numpy(), other)
    return core.Tensor.from_numpy(out)

__all__.append('pow_tensor_scalar')

stop_gradient_op = ops.StopGradient().set_device('CPU')
def stop_gradient(*args):
    return stop_gradient_op(*args)

__all__.append('stop_gradient')

def fmod_scalar(input, other):
    out = np.fmod(input.numpy(), other)
    return core.Tensor.from_numpy(out)

__all__.append('fmod_scalar')

def argmax_with_value(input, dim, keepdim):
    indices = np.argmax(input.numpy(), dim, keepdims=keepdim)
    values = np.max(input.numpy(), dim, keepdims=keepdim)

    if not isinstance(indices, np.ndarray):
        indices = np.array(indices)
    if not isinstance(values, np.ndarray):
        values = np.array(values)
    return core.Tensor.from_numpy(indices), core.Tensor.from_numpy(values)

__all__.append('argmax_with_value')

def argmax_ext(input, dim, keepdim):
    indices = np.argmax(input.numpy(), dim, keepdims=keepdim)
    if not isinstance(indices, np.ndarray):
        indices = np.array(indices)
    return core.Tensor.from_numpy(indices)
__all__.append('argmax_ext')


def log(input):
    out = np.log(input.numpy())
    return core.Tensor.from_numpy(out)

__all__.append('log')

def eye(n, m, dtype):
    out = np.eye(n, m, dtype=core.dtype2np[dtype])
    return core.Tensor.from_numpy(out)

__all__.append('eye')

def lin_space_ext(start, end, steps, dtype):
    out = np.linspace(start, end, steps, dtype=core.dtype2np[dtype])
    return core.Tensor.from_numpy(out)

__all__.append('lin_space_ext')

def upsample_bilinear2d(input, output_size, scale_factors, align_corners):
    resize = _get_cache_prim(ops.ResizeBilinearV2)(align_corners, not align_corners).set_device('CPU')
    return resize(input, output_size)

__all__.append('upsample_bilinear2d')

def split_with_size(tensor, split_size_or_sections, dim):
    out = np.array_split(tensor.numpy(), np.cumsum(split_size_or_sections[:-1]), dim)
    out = [core.Tensor.from_numpy(o) for o in out]
    return out

__all__.append('split_with_size')
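split_with_size maps torch-style section sizes onto np.array_split, which expects split points rather than sizes: the cumulative sums of all but the last size give exactly those points. A worked example under that reading, using the function defined just above:

# Sizes [2, 3, 1] on axis 0 -> split points np.cumsum([2, 3]) == [2, 5],
# producing pieces of length 2, 3 and 1.
x = core.Tensor.from_numpy(np.arange(6))
parts = split_with_size(x, [2, 3, 1], 0)
# parts[0].shape == (2,), parts[1].shape == (3,), parts[2].shape == (1,)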