4 changes: 3 additions & 1 deletion .gitignore
@@ -170,4 +170,6 @@ flagged/
huggingface_transformers/
diffusers/
!mindnlp/diffusers/
.gradio/
.gradio/
transformers/
!mindnlp/transformers/
10 changes: 9 additions & 1 deletion mindnlp/core/_dtype.py
@@ -7,7 +7,7 @@
from .configs import ON_A1

if ON_A1:
warnings.warn('910A do not support bfloat16, use float16 instead.')
warnings.warn('MindSpore on GPU/910A does not support bfloat16, use float16 instead.')
bfloat16 = float16

dtype = Type
@@ -45,6 +45,14 @@ def __gt__(self, other):
float8_e4m3fn = None # TODO: not support fp8 for now
float8_e5m2 = None

uint1 = None
uint2 = None
uint3 = None
uint4 = None
uint5 = None
uint6 = None
uint7 = None

ITEM_SIZE = {
bool : 1,
int8 : 1,
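Since `_dtype.py` aliases `bfloat16` to `float16` when `ON_A1` is set, downstream dtype checks should treat the two as interchangeable on that hardware. A minimal sketch, assuming `mindnlp.core` re-exports the dtype names from `_dtype`:

from mindnlp import core

# On an Atlas 910A (ON_A1) build the alias makes this True;
# on other backends bfloat16 stays a distinct dtype.
print(core.bfloat16 is core.float16)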
1 change: 1 addition & 0 deletions mindnlp/core/_dynamo/_trace_wrapped_higher_order_op.py
@@ -0,0 +1 @@
class TransformGetItemToIndex: pass
56 changes: 34 additions & 22 deletions mindnlp/core/_tensor.py
@@ -1,6 +1,5 @@
import math
import numpy as np
import warnings
import mindspore
from mindspore import Tensor
from mindspore.common.tensor import _TensorMeta
@@ -16,12 +15,11 @@ class StubTensor: pass
from mindspore._c_expression import Tensor as Tensor_

from . import ops, _dtype
from ._dtype import dtype2np
from ._bind import get_default_device, device_, get_default_dtype
from .configs import use_pyboost, ON_A1
from .storage import UntypedStorage
from ._utils import _rebuild_tensor_v2
from ._C.size import Size
from .types import DEVICE_MAP

DTYPE_ELEMENT_SIZE_MAP = {
mindspore.float64: 8,
@@ -34,12 +32,6 @@ class StubTensor: pass
mindspore.bool_: 1
}

DEVICE_MAP = {
'GPU': 'cuda',
'Ascend': 'npu',
'CPU': 'cpu'
}

class TypedTensorMeta(_TensorMeta):
def __isinstancecheck__(self, instance):
if not isinstance(instance, Tensor):
@@ -266,6 +258,8 @@ def __getitem__(self, slices):
for s in slices:
if isinstance(s, range):
s = list(s)
if isinstance(s, np.ndarray):
s = tensor(s)
new_slices += (s,)
slices = new_slices
return origin_getitem(self, slices)
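The extended `__getitem__` wrapper now normalizes `range` and `numpy.ndarray` indices before dispatching to the original getter. A small sketch of the intended behaviour, assuming the torch-style `core.arange` constructor:

import numpy as np
from mindnlp import core

x = core.arange(10)
print(x[np.array([0, 2, 4])])  # ndarray index is wrapped into a tensor
print(x[range(3)])             # range index is converted to a list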
@@ -298,7 +292,7 @@ def _convert_numpy_slices(self, key):
# Convert a single NumPy integer index
elif isinstance(key, np.integer):
return int(key)

# Other types (e.g. int, None) are returned unchanged
else:
return key
@@ -328,7 +322,7 @@ def __setitem__(self, slices, value):

if 1 in value.shape and self[slices].ndim != value.ndim:
value = value.squeeze()

return origin_setitem(self, slices, value)

Tensor.__setitem__ = __setitem__
@@ -362,16 +356,8 @@ def nbytes(self):
Tensor.unsqueeze = ops.unsqueeze
StubTensor.unsqueeze = ops.unsqueeze

def log_softmax(self, dim=-1, dtype=None):
if use_pyboost():
return mindspore.mint.nn.functional.log_softmax(self, dim=dim, dtype=dtype)
out = mindspore.ops.log_softmax(self, dim)
if dtype is not None:
out = out.to(dtype)
return out

Tensor.log_softmax = log_softmax
StubTensor.log_softmax = log_softmax
Tensor.log_softmax = ops.log_softmax
StubTensor.log_softmax = ops.log_softmax

def untyped_storage(self):
return UntypedStorage(self)
@@ -457,7 +443,7 @@ def unfold(self, dimension, size, step):
def new(self, *shape):
if not isinstance(shape[0], int):
return tensor(shape[0], dtype=self.dtype)
return ops.empty(*shape, dtype=self.dtype)
return ops.empty(*shape, dtype=self.dtype, device=self.device)

Tensor.new = new
StubTensor.new = new
@@ -638,9 +624,26 @@ def new_tensor(self, data, *, dtype=None, device=None, requires_grad=False, layo
Tensor.fill_diagonal_ = ops.inplace_fill_diagonal
StubTensor.fill_diagonal_ = ops.inplace_fill_diagonal

Tensor.fill_ = ops.inplace_fill
StubTensor.fill_ = ops.inplace_fill

Tensor.zero_ = ops.inplace_zero
StubTensor.zero_ = ops.inplace_zero

Tensor.uniform_ = ops.inplace_uniform
StubTensor.uniform_ = ops.inplace_uniform

Tensor.random_ = ops.inplace_random
StubTensor.random_ = ops.inplace_random


Tensor.triu_ = ops.inplace_triu
StubTensor.triu_ = ops.inplace_triu

Tensor.masked_fill_ = ops.inplace_masked_fill
StubTensor.masked_fill_ = ops.inplace_masked_fill


@property
def real(self):
return ops.real(self)
@@ -780,6 +783,15 @@ def tobytes(self):
Tensor.cuda = cpu
StubTensor.cuda = cpu

Tensor.nonzero = ops.nonzero
StubTensor.nonzero = ops.nonzero

Tensor.clamp_ = ops.inplace_clamp
StubTensor.clamp_ = ops.inplace_clamp

Tensor.copy_ = ops.inplace_copy
StubTensor.copy_ = ops.inplace_copy

def _rebuild_from_type_v2(func, new_type, args, state):
ret = func(*args)
return ret
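With the new bindings, torch-style in-place methods on both `Tensor` and `StubTensor` resolve to the corresponding `ops.inplace_*` helpers. A usage sketch, assuming these helpers mirror their torch counterparts:

from mindnlp import core

x = core.ones(3, 3)
x.zero_()                  # ops.inplace_zero
x.fill_(0.5)               # ops.inplace_fill
x.masked_fill_(x > 0, 0.)  # ops.inplace_masked_fill
x.clamp_(0., 1.)           # ops.inplace_clamp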
16 changes: 15 additions & 1 deletion mindnlp/core/cuda/__init__.py
@@ -3,6 +3,8 @@
import mindspore
from mindspore import get_rng_state, set_rng_state, manual_seed
from mindspore.hal import *
from mindspore.runtime import memory_reserved as ms_memory_reserved, \
memory_allocated as ms_memory_allocated

from mindnlp import core

@@ -42,4 +44,16 @@ def __enter__(self):
def __exit__(self, type: Any, value: Any, traceback: Any):
return False

OutOfMemoryError = RuntimeError
OutOfMemoryError = RuntimeError

def is_bf16_supported():
return False

def mem_get_info(index):
return (1024, 1024)

def memory_reserved(device=None):
return ms_memory_reserved()

def memory_allocated(device=None):
return ms_memory_allocated()
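The new `core.cuda` helpers are thin shims: `is_bf16_supported` always reports `False`, `mem_get_info` returns fixed placeholder values, and the memory queries forward to `mindspore.runtime` while ignoring the optional `device` argument. For example:

from mindnlp import core

print(core.cuda.is_bf16_supported())   # False
print(core.cuda.mem_get_info(0))       # (1024, 1024) placeholder
print(core.cuda.memory_allocated())    # forwarded to mindspore.runtime
print(core.cuda.memory_reserved())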
6 changes: 4 additions & 2 deletions mindnlp/core/distributed/tensor/__init__.py
@@ -1,4 +1,6 @@
Replicate = None
class Replicate: pass
class Shard:
def __init__(self, *args, **kargs):
pass
class DTensor(): pass
Placement = None
Shard = None
39 changes: 39 additions & 0 deletions mindnlp/core/fx/_compatibility.py
@@ -0,0 +1,39 @@
import textwrap
from typing import Any, Callable, TypeVar


_BACK_COMPAT_OBJECTS: dict[Any, None] = {}
_MARKED_WITH_COMPATIBILITY: dict[Any, None] = {}


_T = TypeVar("_T")


def compatibility(is_backward_compatible: bool) -> Callable[[_T], _T]:
if is_backward_compatible:

def mark_back_compat(fn: _T) -> _T:
docstring = textwrap.dedent(getattr(fn, "__doc__", None) or "")
docstring += """
.. note::
Backwards-compatibility for this API is guaranteed.
"""
fn.__doc__ = docstring
_BACK_COMPAT_OBJECTS.setdefault(fn)
_MARKED_WITH_COMPATIBILITY.setdefault(fn)
return fn

return mark_back_compat
else:

def mark_not_back_compat(fn: _T) -> _T:
docstring = textwrap.dedent(getattr(fn, "__doc__", None) or "")
docstring += """
.. warning::
This API is experimental and is *NOT* backward-compatible.
"""
fn.__doc__ = docstring
_MARKED_WITH_COMPATIBILITY.setdefault(fn)
return fn

return mark_not_back_compat
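The `compatibility` decorator only appends a note (or warning) to the wrapped object's docstring and records it in the module-level registries, so it can be applied to plain functions. A short sketch with a hypothetical helper:

from mindnlp.core.fx._compatibility import compatibility

@compatibility(is_backward_compatible=True)
def trace_helper():
    """Hypothetical helper."""

print(trace_helper.__doc__)  # now carries the backwards-compatibility note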
5 changes: 5 additions & 0 deletions mindnlp/core/fx/_symbolic_trace.py
@@ -12,6 +12,11 @@
from typing import Any, Callable, NamedTuple, Optional, Union

_wrapped_fns_to_patch: dict[tuple[int, str], dict] = {}
_is_fx_tracing_flag = False


def is_fx_tracing():
return _is_fx_tracing_flag

def wrap(fn_or_name: Union[str, Callable]):
"""
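`is_fx_tracing` simply reports the module-level `_is_fx_tracing_flag`, which stays `False` until a tracer flips it, e.g.:

from mindnlp.core.fx._symbolic_trace import is_fx_tracing

print(is_fx_tracing())  # False unless a tracer sets _is_fx_tracing_flag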
3 changes: 3 additions & 0 deletions mindnlp/core/fx/proxy.py
@@ -1,2 +1,5 @@
class Proxy:
pass

class ParameterProxy(Proxy):
pass
1 change: 1 addition & 0 deletions mindnlp/core/nn/attention/flex_attention.py
@@ -1,3 +1,4 @@
BlockMask = None
flex_attention = None
create_block_mask = None
_DEFAULT_SPARSE_BLOCK_SIZE = None
76 changes: 73 additions & 3 deletions mindnlp/core/nn/functional.py
@@ -352,8 +352,75 @@ def pad(input, pad, mode='constant', value=None):
return ops.pad(input, new_pad, mode, value)

def nll_loss(input, target, weight=None, ignore_index=-100, reduction='mean'):
return _nllloss_nd(input, target, weight, ignore_index, reduction)
if input.device.type == 'npu':
return _nllloss_nd(input, target, weight, ignore_index, reduction)
return _inner_nll_loss(input, target, weight, ignore_index, reduction)

def _inner_nll_loss(inputs, target, weight=None, ignore_index=-100, reduction='mean', label_smoothing=0.0):
ndim = inputs.ndim
if ndim == 2:
ret = _nll_loss(inputs, target, -1, weight, ignore_index, reduction, label_smoothing)
elif ndim == 4:
ret = _nll_loss(inputs, target, 1, weight, ignore_index, reduction, label_smoothing)
elif ndim == 1:
ret = _nll_loss(inputs, target, 0, weight, ignore_index, reduction, label_smoothing)
else:
n = inputs.shape[0]
c = inputs.shape[1]
out_size = (n,) + inputs.shape[2:]
inputs = inputs.view((n, c, 1, -1))
target = target.view((n, 1, -1))
if reduction != 'none':
ret = _nll_loss(inputs, target, 1, weight, ignore_index, reduction, label_smoothing)
else:
ret = _nll_loss(inputs, target, 1, weight, ignore_index, label_smoothing=label_smoothing)
ret = ret.view(out_size)
return ret

def _nll_loss(inputs, target, target_dim=-1, weight=None, ignore_index=None, reduction='none', label_smoothing=0.0):
"""nll loss inner function"""
if target.ndim == inputs.ndim - 1:
target = target.expand_dims(target_dim)
if ignore_index is not None:
non_pad_mask = ops.equal(target, ignore_index)
target = target.masked_fill(non_pad_mask, ops.cast(0, target.dtype))
else:
non_pad_mask = target
if weight is not None:
loss_weights = ops.gather(weight, target, 0)
orig_shape = inputs.shape
if inputs.ndim != 2:
inputs = inputs.view(orig_shape[:2] + (-1,))
weight = weight.view(weight.shape + (1,))
weighted_inputs = inputs * weight
weighted_inputs = weighted_inputs.view(orig_shape)
loss = ops.neg(ops.gather_d(weighted_inputs, target_dim, target))
smooth_loss = ops.neg(weighted_inputs.sum(axis=target_dim, keepdims=True))
else:
loss = ops.neg(ops.gather_d(inputs, target_dim, target))
smooth_loss = ops.neg(inputs.sum(axis=target_dim, keepdims=True))
loss_weights = ops.ones_like(loss)

if ignore_index is not None:
loss = loss.masked_fill(non_pad_mask, ops.cast(0, loss.dtype))
loss_weights = loss_weights.masked_fill(non_pad_mask, ops.cast(0, loss_weights.dtype))
smooth_loss = smooth_loss.masked_fill(non_pad_mask, ops.cast(0, smooth_loss.dtype))

loss = loss.squeeze(target_dim)
smooth_loss = smooth_loss.squeeze(target_dim)

if reduction == 'sum':
loss = loss.sum()
smooth_loss = smooth_loss.sum()
if reduction == 'mean':
loss = loss.sum() / loss_weights.sum()
smooth_loss = smooth_loss.sum() / loss_weights.sum()

eps_i = label_smoothing / inputs.shape[target_dim]
if label_smoothing != 0:
loss = (1. - label_smoothing) * loss + eps_i * smooth_loss

return loss

def _nllloss_nd(input, target, weight=None, ingore_index=-100, reduction='mean'):
"""nllloss_nd inner function"""
@@ -402,8 +469,11 @@ def cross_entropy(input, target, weight=None, ignore_index=-100, reduction='mean
return _cross_entropy_for_probabilities(input, target, weight, reduction, label_smoothing, class_dim,
n_classes)
# for class indices
return _cross_entropy_for_class_indices(input, target, weight, ignore_index, reduction, label_smoothing,
class_dim, n_classes)
if input.device.type == 'npu':
return _cross_entropy_for_class_indices(input, target, weight, ignore_index, reduction, label_smoothing,
class_dim, n_classes)
return _inner_nll_loss(log_softmax(input, class_dim), target, weight, ignore_index, reduction, label_smoothing)


def _cross_entropy_for_probabilities(input, target, weight, reduction, label_smoothing, class_dim, n_classes):
"""cross_entropy inner function for class probabilities"""
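With this change `nll_loss` and `cross_entropy` keep the `_nllloss_nd` / `_cross_entropy_for_class_indices` paths on NPU and fall back to the pure-ops `_inner_nll_loss` elsewhere. A usage sketch, assuming torch-style `core.randn` / `core.randint` constructors:

from mindnlp import core
from mindnlp.core.nn import functional as F

logits = core.randn(4, 10)
labels = core.randint(0, 10, (4,))
print(F.cross_entropy(logits, labels))                # class-index targets
print(F.nll_loss(F.log_softmax(logits, -1), labels))  # same loss via nll_loss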
2 changes: 2 additions & 0 deletions mindnlp/core/nn/parameter.py
@@ -72,6 +72,8 @@ def requires_grad(self, value):
self.handle.remove()
delattr(self, 'handle')

def retain_grad(self):
pass

class UninitializedParameter(Parameter):
def __init__(self, input_data=None, requires_grad=True):
6 changes: 4 additions & 2 deletions mindnlp/core/ops/comparison.py
@@ -7,7 +7,7 @@

from ._inner import call_ms_func

sort_out = namedtuple('sort_out', ['sorted', 'indices'])
sort_out = namedtuple('sort_out', ['values', 'indices'])
topk_out = namedtuple('topk_out', ['values', 'indices'])
# allclose
has_allclose = hasattr(mindspore.mint, 'allclose')
@@ -30,6 +30,8 @@ def argsort(input, dim=-1, descending=False, stable=False):
def eq(input, other, *, out=None):
if use_pyboost() and has_eq:
return call_ms_func(mindspore.mint.eq, input, other, out=out)
if isinstance(other, str):
return False
return call_ms_func(ops.eq, input, other, out=out)

# equal
@@ -176,7 +178,7 @@ def sort(input, *, dim=-1, descending=False, stable=False):
out = mindspore.mint.sort(input, dim=dim, descending=descending, stable=stable)
else:
out = ops.sort(input, dim, descending)
return sort_out(sorted=out[0], indices=out[1])
return sort_out(values=out[0], indices=out[1])

# topk
has_topk = hasattr(mindspore.mint, 'topk')
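After the rename, the first field of `sort`'s return value matches `torch.sort` (`values` instead of `sorted`), so code written against the torch API unpacks it directly. For example, assuming the comparison ops are re-exported under `core.ops`:

from mindnlp import core

out = core.ops.sort(core.tensor([3, 1, 2]))
print(out.values, out.indices)   # previously: out.sorted, out.indices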