diff --git a/.gitignore b/.gitignore
index 8c133a656..202643cae 100644
--- a/.gitignore
+++ b/.gitignore
@@ -170,4 +170,6 @@ flagged/
 huggingface_transformers/
 diffusers/
 !mindnlp/diffusers/
-.gradio/
\ No newline at end of file
+.gradio/
+transformers/
+!mindnlp/transformers/
\ No newline at end of file
diff --git a/mindnlp/core/_dtype.py b/mindnlp/core/_dtype.py
index 53788b962..d480e7f9c 100644
--- a/mindnlp/core/_dtype.py
+++ b/mindnlp/core/_dtype.py
@@ -7,7 +7,7 @@
 from .configs import ON_A1
 
 if ON_A1:
-    warnings.warn('910A do not support bfloat16, use float16 instead.')
+    warnings.warn('MindSpore on GPU/910A does not support bfloat16, use float16 instead.')
     bfloat16 = float16
 
 dtype = Type
@@ -45,6 +45,14 @@ def __gt__(self, other):
 float8_e4m3fn = None # TODO: not support fp8 for now
 float8_e5m2 = None
 
+uint1 = None
+uint2 = None
+uint3 = None
+uint4 = None
+uint5 = None
+uint6 = None
+uint7 = None
+
 ITEM_SIZE = {
     bool : 1,
     int8 : 1,
diff --git a/mindnlp/core/_dynamo/_trace_wrapped_higher_order_op.py b/mindnlp/core/_dynamo/_trace_wrapped_higher_order_op.py
new file mode 100644
index 000000000..08a078b7c
--- /dev/null
+++ b/mindnlp/core/_dynamo/_trace_wrapped_higher_order_op.py
@@ -0,0 +1 @@
+class TransformGetItemToIndex: pass
diff --git a/mindnlp/core/_tensor.py b/mindnlp/core/_tensor.py
index 14f97a5d5..02ffdbbf1 100644
--- a/mindnlp/core/_tensor.py
+++ b/mindnlp/core/_tensor.py
@@ -1,6 +1,5 @@
 import math
 import numpy as np
-import warnings
 import mindspore
 from mindspore import Tensor
 from mindspore.common.tensor import _TensorMeta
@@ -16,12 +15,11 @@ class StubTensor: pass
 from mindspore._c_expression import Tensor as Tensor_
 
 from . import ops, _dtype
-from ._dtype import dtype2np
 from ._bind import get_default_device, device_, get_default_dtype
-from .configs import use_pyboost, ON_A1
 from .storage import UntypedStorage
 from ._utils import _rebuild_tensor_v2
 from ._C.size import Size
+from .types import DEVICE_MAP
 
 DTYPE_ELEMENT_SIZE_MAP = {
     mindspore.float64: 8,
@@ -34,12 +32,6 @@ class StubTensor: pass
     mindspore.bool_: 1
 }
 
-DEVICE_MAP = {
-    'GPU': 'cuda',
-    'Ascend': 'npu',
-    'CPU': 'cpu'
-}
-
 class TypedTensorMeta(_TensorMeta):
     def __isinstancecheck__(self, instance):
         if not isinstance(instance, Tensor):
@@ -266,6 +258,8 @@ def __getitem__(self, slices):
         for s in slices:
             if isinstance(s, range):
                 s = list(s)
+            if isinstance(s, np.ndarray):
+                s = tensor(s)
             new_slices += (s,)
         slices = new_slices
         return origin_getitem(self, slices)
@@ -298,7 +292,7 @@ def _convert_numpy_slices(self, key):
         # Convert a single NumPy index value
         elif isinstance(key, np.integer):
             return int(key)
-        
+
         # Other types (e.g. int, None) are returned unchanged
         else:
             return key
@@ -328,7 +322,7 @@ def __setitem__(self, slices, value):
 
         if 1 in value.shape and self[slices].ndim != value.ndim:
             value = value.squeeze()
-        
+
         return origin_setitem(self, slices, value)
 
     Tensor.__setitem__ = __setitem__
@@ -362,16 +356,8 @@ def nbytes(self):
     Tensor.unsqueeze = ops.unsqueeze
     StubTensor.unsqueeze = ops.unsqueeze
 
-    def log_softmax(self, dim=-1, dtype=None):
-        if use_pyboost():
-            return mindspore.mint.nn.functional.log_softmax(self, dim=dim, dtype=dtype)
-        out = mindspore.ops.log_softmax(self, dim)
-        if dtype is not None:
-            out = out.to(dtype)
-        return out
-
-    Tensor.log_softmax = log_softmax
-    StubTensor.log_softmax = log_softmax
+    Tensor.log_softmax = ops.log_softmax
+    StubTensor.log_softmax = ops.log_softmax
 
     def untyped_storage(self):
         return UntypedStorage(self)
@@ -457,7 +443,7 @@ def unfold(self, dimension, size, step):
     def new(self, *shape):
         if not isinstance(shape[0], int):
             return tensor(shape[0], dtype=self.dtype)
-        return ops.empty(*shape, dtype=self.dtype)
+        return ops.empty(*shape, dtype=self.dtype, device=self.device)
 
     Tensor.new = new
     StubTensor.new = new
@@ -638,9 +624,26 @@ def new_tensor(self, data, *, dtype=None, device=None, requires_grad=False, layo
     Tensor.fill_diagonal_ = ops.inplace_fill_diagonal
     StubTensor.fill_diagonal_ = ops.inplace_fill_diagonal
 
+    Tensor.fill_ = ops.inplace_fill
+    StubTensor.fill_ = ops.inplace_fill
+
+    Tensor.zero_ = ops.inplace_zero
+    StubTensor.zero_ = ops.inplace_zero
+
+    Tensor.uniform_ = ops.inplace_uniform
+    StubTensor.uniform_ = ops.inplace_uniform
+
+    Tensor.random_ = ops.inplace_random
+    StubTensor.random_ = ops.inplace_random
+
+
     Tensor.triu_ = ops.inplace_triu
     StubTensor.triu_ = ops.inplace_triu
 
+    Tensor.masked_fill_ = ops.inplace_masked_fill
+    StubTensor.masked_fill_ = ops.inplace_masked_fill
+
+
     @property
     def real(self):
         return ops.real(self)
@@ -780,6 +783,15 @@ def tobytes(self):
     Tensor.cuda = cpu
     StubTensor.cuda = cpu
 
+    Tensor.nonzero = ops.nonzero
+    StubTensor.nonzero = ops.nonzero
+
+    Tensor.clamp_ = ops.inplace_clamp
+    StubTensor.clamp_ = ops.inplace_clamp
+
+    Tensor.copy_ = ops.inplace_copy
+    StubTensor.copy_ = ops.inplace_copy
+
 def _rebuild_from_type_v2(func, new_type, args, state):
     ret = func(*args)
     return ret
\ No newline at end of file
diff --git a/mindnlp/core/cuda/__init__.py b/mindnlp/core/cuda/__init__.py
index 2d72a8d48..a1da3c032 100644
--- a/mindnlp/core/cuda/__init__.py
+++ b/mindnlp/core/cuda/__init__.py
@@ -3,6 +3,8 @@
 import mindspore
 from mindspore import get_rng_state, set_rng_state, manual_seed
 from mindspore.hal import *
+from mindspore.runtime import memory_reserved as ms_memory_reserved, \
+    memory_allocated as ms_memory_allocated
 
 from mindnlp import core
 
@@ -42,4 +44,16 @@ def __enter__(self):
     def __exit__(self, type: Any, value: Any, traceback: Any):
         return False
 
-OutOfMemoryError = RuntimeError
\ No newline at end of file
+OutOfMemoryError = RuntimeError
+
+def is_bf16_supported():
+    return False
+
+def mem_get_info(index):
+    return (1024, 1024)
+
+def memory_reserved(device=None):
+    return ms_memory_reserved()
+
+def memory_allocated(device=None):
+    return ms_memory_allocated()
\ No newline at end of file
diff --git a/mindnlp/core/distributed/tensor/__init__.py b/mindnlp/core/distributed/tensor/__init__.py
index 9790c03e8..e4136ff5c 100644
--- a/mindnlp/core/distributed/tensor/__init__.py
+++ b/mindnlp/core/distributed/tensor/__init__.py
@@ -1,4 +1,6 @@
-Replicate = None
+class Replicate: pass
+class Shard:
+    def __init__(self, *args, **kwargs):
+        pass
 class DTensor(): pass
 Placement = None
-Shard = None
diff --git a/mindnlp/core/fx/_compatibility.py b/mindnlp/core/fx/_compatibility.py
index e69de29bb..26bb3ff3b 100644
--- a/mindnlp/core/fx/_compatibility.py
+++ b/mindnlp/core/fx/_compatibility.py
@@ -0,0 +1,39 @@
+import textwrap
+from typing import Any, Callable, TypeVar
+
+
+_BACK_COMPAT_OBJECTS: dict[Any, None] = {}
+_MARKED_WITH_COMPATIBILITY: dict[Any, None] = {}
+
+
+_T = TypeVar("_T")
+
+
+def compatibility(is_backward_compatible: bool) -> Callable[[_T], _T]:
+    if is_backward_compatible:
+
+        def mark_back_compat(fn: _T) -> _T:
+            docstring = textwrap.dedent(getattr(fn, "__doc__", None) or "")
+            docstring += """
+.. note::
+    Backwards-compatibility for this API is guaranteed.
+""" + fn.__doc__ = docstring + _BACK_COMPAT_OBJECTS.setdefault(fn) + _MARKED_WITH_COMPATIBILITY.setdefault(fn) + return fn + + return mark_back_compat + else: + + def mark_not_back_compat(fn: _T) -> _T: + docstring = textwrap.dedent(getattr(fn, "__doc__", None) or "") + docstring += """ +.. warning:: + This API is experimental and is *NOT* backward-compatible. +""" + fn.__doc__ = docstring + _MARKED_WITH_COMPATIBILITY.setdefault(fn) + return fn + + return mark_not_back_compat diff --git a/mindnlp/core/fx/_symbolic_trace.py b/mindnlp/core/fx/_symbolic_trace.py index 2e9d28668..f712d1a3c 100644 --- a/mindnlp/core/fx/_symbolic_trace.py +++ b/mindnlp/core/fx/_symbolic_trace.py @@ -12,6 +12,11 @@ from typing import Any, Callable, NamedTuple, Optional, Union _wrapped_fns_to_patch: dict[tuple[int, str], dict] = {} +_is_fx_tracing_flag = False + + +def is_fx_tracing(): + return _is_fx_tracing_flag def wrap(fn_or_name: Union[str, Callable]): """ diff --git a/mindnlp/core/fx/proxy.py b/mindnlp/core/fx/proxy.py index 46b350813..2d653b69f 100644 --- a/mindnlp/core/fx/proxy.py +++ b/mindnlp/core/fx/proxy.py @@ -1,2 +1,5 @@ class Proxy: pass + +class ParameterProxy(Proxy): + pass diff --git a/mindnlp/core/nn/attention/flex_attention.py b/mindnlp/core/nn/attention/flex_attention.py index f9205be20..c2db2c1e4 100644 --- a/mindnlp/core/nn/attention/flex_attention.py +++ b/mindnlp/core/nn/attention/flex_attention.py @@ -1,3 +1,4 @@ BlockMask = None flex_attention = None create_block_mask = None +_DEFAULT_SPARSE_BLOCK_SIZE = None \ No newline at end of file diff --git a/mindnlp/core/nn/functional.py b/mindnlp/core/nn/functional.py index 70f4f5b0d..22b41035f 100644 --- a/mindnlp/core/nn/functional.py +++ b/mindnlp/core/nn/functional.py @@ -352,8 +352,75 @@ def pad(input, pad, mode='constant', value=None): return ops.pad(input, new_pad, mode, value) def nll_loss(input, target, weight=None, ignore_index=-100, reduction='mean'): - return _nllloss_nd(input, target, weight, ignore_index, reduction) + if input.device.type == 'npu': + return _nllloss_nd(input, target, weight, ignore_index, reduction) + return _inner_nll_loss(input, target, weight, ignore_index, reduction) + +def _inner_nll_loss(inputs, target, weight=None, ignore_index=-100, reduction='mean', label_smoothing=0.0): + ndim = inputs.ndim + if ndim == 2: + ret = _nll_loss(inputs, target, -1, weight, ignore_index, reduction, label_smoothing) + elif ndim == 4: + ret = _nll_loss(inputs, target, 1, weight, ignore_index, reduction, label_smoothing) + elif ndim == 1: + ret = _nll_loss(inputs, target, 0, weight, ignore_index, reduction, label_smoothing) + else: + n = inputs.shape[0] + c = inputs.shape[1] + out_size = (n,) + inputs.shape[2:] + inputs = inputs.view((n, c, 1, -1)) + target = target.view((n, 1, -1)) + if reduction != 'none': + ret = _nll_loss(inputs, target, 1, weight, ignore_index, reduction, label_smoothing) + else: + ret = _nll_loss(inputs, target, 1, weight, ignore_index, label_smoothing=label_smoothing) + ret = ret.view(out_size) + return ret + +def _nll_loss(inputs, target, target_dim=-1, weight=None, ignore_index=None, reduction='none', label_smoothing=0.0): + """nll loss inner function""" + if target.ndim == inputs.ndim - 1: + target = target.expand_dims(target_dim) + if ignore_index is not None: + non_pad_mask = ops.equal(target, ignore_index) + target = target.masked_fill(non_pad_mask, ops.cast(0, target.dtype)) + else: + non_pad_mask = target + if weight is not None: + loss_weights = ops.gather(weight, target, 0) + orig_shape = 
+        orig_shape = inputs.shape
+        if inputs.ndim != 2:
+            inputs = inputs.view(orig_shape[:2] + (-1,))
+            weight = weight.view(weight.shape + (1,))
+        weighted_inputs = inputs * weight
+        weighted_inputs = weighted_inputs.view(orig_shape)
+        loss = ops.neg(ops.gather_d(weighted_inputs, target_dim, target))
+        smooth_loss = ops.neg(weighted_inputs.sum(axis=target_dim, keepdims=True))
+    else:
+        loss = ops.neg(ops.gather_d(inputs, target_dim, target))
+        smooth_loss = ops.neg(inputs.sum(axis=target_dim, keepdims=True))
+        loss_weights = ops.ones_like(loss)
+
+    if ignore_index is not None:
+        loss = loss.masked_fill(non_pad_mask, ops.cast(0, loss.dtype))
+        loss_weights = loss_weights.masked_fill(non_pad_mask, ops.cast(0, loss_weights.dtype))
+        smooth_loss = smooth_loss.masked_fill(non_pad_mask, ops.cast(0, smooth_loss.dtype))
+    loss = loss.squeeze(target_dim)
+    smooth_loss = smooth_loss.squeeze(target_dim)
+
+    if reduction == 'sum':
+        loss = loss.sum()
+        smooth_loss = smooth_loss.sum()
+    if reduction == 'mean':
+        loss = loss.sum() / loss_weights.sum()
+        smooth_loss = smooth_loss.sum() / loss_weights.sum()
+
+    eps_i = label_smoothing / inputs.shape[target_dim]
+    if label_smoothing != 0:
+        loss = (1. - label_smoothing) * loss + eps_i * smooth_loss
+
+    return loss
 
 def _nllloss_nd(input, target, weight=None, ingore_index=-100, reduction='mean'):
     """nllloss_nd inner function"""
@@ -402,8 +469,11 @@ def cross_entropy(input, target, weight=None, ignore_index=-100, reduction='mean
         return _cross_entropy_for_probabilities(input, target, weight, reduction, label_smoothing, class_dim, n_classes)
 
     # for class indices
-    return _cross_entropy_for_class_indices(input, target, weight, ignore_index, reduction, label_smoothing,
-                                            class_dim, n_classes)
+    if input.device.type == 'npu':
+        return _cross_entropy_for_class_indices(input, target, weight, ignore_index, reduction, label_smoothing,
+                                                class_dim, n_classes)
+    return _inner_nll_loss(log_softmax(input, class_dim), target, weight, ignore_index, reduction, label_smoothing)
+
 
 def _cross_entropy_for_probabilities(input, target, weight, reduction, label_smoothing, class_dim, n_classes):
     """cross_entropy inner function for class probabilities"""
diff --git a/mindnlp/core/nn/parameter.py b/mindnlp/core/nn/parameter.py
index 4393ffb99..1fb06de52 100644
--- a/mindnlp/core/nn/parameter.py
+++ b/mindnlp/core/nn/parameter.py
@@ -72,6 +72,8 @@ def requires_grad(self, value):
             self.handle.remove()
             delattr(self, 'handle')
 
+    def retain_grad(self):
+        pass
 
 class UninitializedParameter(Parameter):
     def __init__(self, input_data=None, requires_grad=True):
diff --git a/mindnlp/core/ops/comparison.py b/mindnlp/core/ops/comparison.py
index c41b8a232..78cd343e0 100644
--- a/mindnlp/core/ops/comparison.py
+++ b/mindnlp/core/ops/comparison.py
@@ -7,7 +7,7 @@
 from ._inner import call_ms_func
 
-sort_out = namedtuple('sort_out', ['sorted', 'indices'])
+sort_out = namedtuple('sort_out', ['values', 'indices'])
 topk_out = namedtuple('topk_out', ['values', 'indices'])
 
 # allclose
 has_allclose = hasattr(mindspore.mint, 'allclose')
@@ -30,6 +30,8 @@ def argsort(input, dim=-1, descending=False, stable=False):
 def eq(input, other, *, out=None):
     if use_pyboost() and has_eq:
         return call_ms_func(mindspore.mint.eq, input, other, out=out)
+    if isinstance(other, str):
+        return False
     return call_ms_func(ops.eq, input, other, out=out)
 
 # equal
@@ -176,7 +178,7 @@ def sort(input, *, dim=-1, descending=False, stable=False):
         out = mindspore.mint.sort(input, dim=dim, descending=descending, stable=stable)
     else:
         out = ops.sort(input, dim, descending)
-    return sort_out(sorted=out[0], indices=out[1])
+    return sort_out(values=out[0], indices=out[1])
 
 # topk
 has_topk = hasattr(mindspore.mint, 'topk')
diff --git a/mindnlp/core/ops/creation.py b/mindnlp/core/ops/creation.py
index 848fb625c..91e07f539 100644
--- a/mindnlp/core/ops/creation.py
+++ b/mindnlp/core/ops/creation.py
@@ -121,6 +121,9 @@ def arange(start=0, end=None, step=1, *, dtype=None, device=None):
         step = step.item() if isinstance(step, (mindspore.Tensor, np.integer)) else step
         return mindspore.mint.arange(start, end, step, dtype=dtype)
 
+    if end is None:
+        end = start
+        start = 0
     start = mindspore.Tensor(start) if not isinstance(start, mindspore.Tensor) else start
     end = mindspore.Tensor(end) if not isinstance(start, mindspore.Tensor) else end
     step = mindspore.Tensor(step) if not isinstance(start, mindspore.Tensor) else step
@@ -183,18 +186,17 @@ def empty(*size, dtype=None, device=None, requires_grad=False, pin_memory=False,
     if not isinstance(device, str) and hasattr(device, "type"):
         device = device.type
     if device.lower() == 'cpu':
-        device = device.upper()
-    elif device.lower() in ['cuda', 'npu']:
+        device = 'CPU'
+    elif device.lower() == 'npu':
         device = 'Ascend'
     else:
-        device = 'CPU'
+        device = 'GPU'
 
     # To avoid the problem in irecv and recv of using empty.
-
-    if has_empty:
+    if has_empty and use_pyboost():
         out = mindspore.mint.empty(size, dtype=dtype, device=device)
     else:
-        out = CTensor(dtype, size)
+        out = CTensor(dtype=dtype, shape=size)
         out = mindspore.Tensor(out)
     if requires_grad:
         out.requires_grad = True
diff --git a/mindnlp/core/ops/inplace.py b/mindnlp/core/ops/inplace.py
index a5dbef4f2..fa53de7e1 100644
--- a/mindnlp/core/ops/inplace.py
+++ b/mindnlp/core/ops/inplace.py
@@ -1,10 +1,13 @@
 import numbers
 import mindspore
 from mindspore import ops
+from mindspore._c_expression import typing
 from mindspore.ops._primitive_cache import _get_cache_prim
 from mindspore.common.generator import default_generator
 from mindspore.ops.auto_generate.gen_ops_prim import inplace_normal_op, inplace_scatter_value_op, inplace_scatter_src_reduce_op, \
-    inplace_scatter_src_op
+    inplace_scatter_src_op, inplace_fill_tensor_op, inplace_fill_scalar_op, inplace_zero_op, inplace_uniform_op, \
+    inplace_masked_fill_scalar_op, inplace_masked_fill_tensor_op, inplace_random_op, inplace_clamp_scalar_op, \
+    inplace_clamp_tensor_op, inplace_copy_op
 
 from mindnlp import core
 from ..configs import use_pyboost
@@ -13,34 +16,27 @@
 generator_step_ = 12
 
 def inplace_copy(self, other):
-    if self.device != other.device:
-        other = other.to(self.device)
-    if self.device.type == 'cpu':
-        # execute('assign', self, other)
-        # # self._data.assign_value_cpp(other._data)
-        self.data = other
+    if self.device.type == 'npu':
+        inplace_copy_op(self, other)
     else:
-        execute('inplace_copy', self, other)
+        self.data = other
     return self
 
 def inplace_zero(input):
-    device = input.device
     if input.device == 'npu':
-        execute('inplace_zero', input)
-    elif input.device.type == 'cpu':
-        out = execute('zeros', input.shape, input.dtype, device=device)
-        input.data = out
+        inplace_zero_op(input)
+    else:
+        input.data = ops.zeros(input.shape, dtype=input.dtype)
     return input
 
 def inplace_fill(input, value):
-    device = input.device
-    if input.device == 'npu':
+    if input.device.type == 'npu':
         if isinstance(value, (int, float, bool)):
-            execute('inplace_fill_scalar', input, value)
-        execute('inplace_fill_tensor', input, value)
-    elif input.device.type == 'cpu':
-        out = execute('full', input.shape, value, device=device)
-        input.data = out
+            inplace_fill_scalar_op(input, value)
+        else:
+            inplace_fill_tensor_op(input, value)
+    else:
+        input.data = ops.full(input.shape, value, dtype=input.dtype)
     return input
 
 def inplace_normal(input, mean=0, std=1, *, generator=None):
@@ -51,8 +47,10 @@ def inplace_normal(input, mean=0, std=1, *, generator=None):
         mean = mean.item()
     if isinstance(std, core.Tensor):
         std = std.item()
-    inplace_normal_op(input, mean, std, seed, offset)
-
+    if input.device.type == 'npu':
+        inplace_normal_op(input, mean, std, seed, offset)
+    else:
+        input.data = ops.normal(input.shape, mean, std)
     return input
 
 # uniform_
@@ -77,8 +75,8 @@ def inplace_uniform(input, *args, **kwargs):
         generator_ = default_generator
     seed, offset = generator_._step(generator_step_)
     if input.device.type == 'npu':
-        execute("inplace_uniform", input, from_, to_, seed, offset)
-    elif input.device.type == 'cpu':
+        inplace_uniform_op(input, from_, to_, seed, offset)
+    else:
         input.data = core.rand(input.shape, generator=generator_, dtype=input.dtype) * (to_ - from_) + from_
     return input
@@ -189,6 +187,42 @@ def inplace_tril(self, diagonal=0):
     self.data = core.tril(self, diagonal)
     return self
 
+def inplace_masked_fill(self, mask, value):
+    if self.device.type == 'npu':
+        if isinstance(value, (int, float, bool)):
+            inplace_masked_fill_scalar_op(self, mask, value)
+        else:
+            inplace_masked_fill_tensor_op(self, mask, value)
+    else:
+        self.data = ops.masked_fill(self, mask, value)
+    return self
+
+def inplace_random(self, from_=0, to=None, *, generator=None):
+    if self.device.type == 'npu':
+        if not generator:
+            generator = default_generator
+        seed, offset = generator._step(  # pylint: disable=protected-access
+            generator_step_)
+        return inplace_random_op(self, from_, to, seed, offset)
+    else:
+        if isinstance(self.dtype, typing.Float):
+            self.uniform_(from_, to, generator=generator)
+        elif isinstance(self.dtype, typing.Int):
+            if to is None:
+                to = core.iinfo(mindspore.int32).max
+            self.data = core.randint(from_, to, size=self.shape, dtype=self.dtype)
+    return self
+
+def inplace_clamp(self, min=None, max=None):
+    if self.device.type == 'npu':
+        if isinstance(min, (int, float, bool)) or isinstance(max, (int, float, bool)):
+            inplace_clamp_scalar_op(self, min, max)
+        else:
+            inplace_clamp_tensor_op(self, min, max)
+    else:
+        self.data = ops.clamp(self, min, max)
+    return self
+
 __all__ = [
     'inplace_copy',
     'inplace_zero',
@@ -212,5 +246,8 @@ def inplace_tril(self, diagonal=0):
     'inplace_exp',
     'inplace_sub',
     'inplace_bernoulli',
-    'inplace_tril'
+    'inplace_tril',
+    'inplace_masked_fill',
+    'inplace_random',
+    'inplace_clamp'
 ]
diff --git a/mindnlp/core/ops/other.py b/mindnlp/core/ops/other.py
index 96e980612..57189f5bb 100644
--- a/mindnlp/core/ops/other.py
+++ b/mindnlp/core/ops/other.py
@@ -26,8 +26,17 @@ def bincount(input, weights=None, minlength=0):
     if use_pyboost() and has_bincount:
         return mindspore.mint.bincount(input, weights, minlength)
-    return ops.bincount(input, weights, minlength)
-
+    if input.max() > minlength - 1:
+        length = (input.max() + 1)
+    else:
+        length = core.tensor(minlength)
+    idx = core.arange(length).unsqueeze(-1)
+    idx_mapping = core.eq(input, idx)
+    if weights is not None:
+        if input.shape != weights.shape:
+            raise ValueError('for bincount `input` and `weights` must have the same length')
+        idx_mapping = weights * idx_mapping
+    return core.sum(idx_mapping, 1).ravel()
 
 # block_diag
@@ -906,6 +915,13 @@ def __init__(self, bits, min, max, eps, tiny, smallest_normal, resolution, dtype
         self.resolution = resolution
         self.dtype = dtype
 
+class iinfo:
+    def __init__(self, bits, min, max, dtype):
+        self.bits = bits
+        self.min = min
+        self.max = max
+        self.dtype = dtype
+
 
 finfo_dtype = {
     mindspore.bfloat16: finfo(
@@ -955,8 +971,13 @@ def finfo(dtype):
     return finfo_dtype[dtype]
 
+iinfo_dtype = {
+    mindspore.int64: iinfo(bits=64, min=-9223372036854775808, max=9223372036854775807, dtype='int64'),
+    mindspore.int32: iinfo(bits=32, min=-2147483648, max=2147483647, dtype='int32')
+}
+
 def iinfo(dtype):
-    return np.iinfo(mindspore.dtype_to_nptype(dtype))
+    return iinfo_dtype[dtype]
 
 def contains(self, key):
diff --git a/mindnlp/core/ops/pointwise.py b/mindnlp/core/ops/pointwise.py
index e226311ee..f6885f288 100644
--- a/mindnlp/core/ops/pointwise.py
+++ b/mindnlp/core/ops/pointwise.py
@@ -64,6 +64,8 @@ def add(input, other, *, alpha=1, out=None):
         return call_ms_func(mindspore.mint.add, input, other, alpha=alpha, out=out)
     if alpha != 1:
         other = mul(alpha, other)
+    if input.dtype == mindspore.bool_:
+        return ops.add(input.int(), other.int()).bool()
     return call_ms_func(ops.add, input, other, out=out)
 
@@ -581,9 +583,15 @@ def igammac(input, other):
 
 def mul(input, other, *, out=None):
     if use_pyboost() and has_mul:
-        out = call_ms_func(mindspore.mint.mul, input, other, out=out)
+        out = mindspore.mint.mul(input, other)
     else:
-        out = call_ms_func(ops.mul, input, other, out=out)
+        if input.dtype == mindspore.bool_:
+            if isinstance(other, bool):
+                out = ops.bitwise_and(input, other)
+            else:
+                out = ops.mul(input.int(), other).bool()
+        else:
+            out = ops.mul(input, other)
 
     if isinstance(other, mindspore.Tensor):
         out_dtype = min(input.dtype, other.dtype)
diff --git a/mindnlp/core/ops/reduction.py b/mindnlp/core/ops/reduction.py
index e8d19c20d..0ad1b7064 100644
--- a/mindnlp/core/ops/reduction.py
+++ b/mindnlp/core/ops/reduction.py
@@ -68,7 +68,7 @@ def any(input, dim=None, keepdim=False, *, out=None):
             return call_ms_func(mindspore.mint.any, input, out=out)
         else:
             return call_ms_func(mindspore.mint.any, input, dim, keepdim, out=out)
-    return call_ms_func(ops.any, input, dim, out=out)
+    return ops.any(input, dim, keepdim)
 
 # max
 has_max = hasattr(mindspore.mint, 'max')
@@ -139,7 +139,9 @@ def nanmedian(input, dim=-1, keepdim=False):
 def norm(input, p='fro', dim=None, keepdim=False, out=None, dtype=None):
     if use_pyboost() and has_norm:
         return call_ms_func(mindspore.mint.norm, input, p, dim, keepdim, out=out, dtype=dtype)
-    return call_ms_func(ops.norm, input, p, dim, keepdim, out=out, dtype=dtype)
+    if p == 'fro':
+        p = None
+    return ops.norm(input, p, dim, keepdim, dtype=dtype)
 
 # nansum
 has_nansum = hasattr(mindspore.mint, 'nansum')
diff --git a/mindnlp/core/profiler/__init__.py b/mindnlp/core/profiler/__init__.py
index 379e8c3e6..270353c36 100644
--- a/mindnlp/core/profiler/__init__.py
+++ b/mindnlp/core/profiler/__init__.py
@@ -1,3 +1,13 @@
-from mindspore.profiler import ProfilerActivity
-
+from contextlib import contextmanager
 from .profiler import profile, tensorboard_trace_handler
+from .scheduler import Schedule as schedule
+from .experimental_config import AiCMetrics, ProfilerLevel, _ExperimentalConfig, ExportType
+from .common import ProfilerActivity
+
+__all__ = ["profile", "ProfilerActivity", "tensorboard_trace_handler", "schedule",
+           "_ExperimentalConfig", "ProfilerLevel", "AiCMetrics", "ExportType"]
+
+
+@contextmanager
+def record_function(name):
+    yield
\ No newline at end of file
diff --git a/mindnlp/core/profiler/common.py b/mindnlp/core/profiler/common.py
new file mode 100644
index 000000000..32cf10516
--- /dev/null
+++ b/mindnlp/core/profiler/common.py
@@ -0,0 +1,9 @@
+from enum import Enum
+
+class ProfilerActivity(Enum):
+    """The profiler activity enum."""
+
+    NPU = "NPU"
+    GPU = "GPU"
+    CPU = "CPU"
+    CUDA = "GPU"
\ No newline at end of file
diff --git a/mindnlp/core/types.py b/mindnlp/core/types.py
index acfc2ebba..8bd350681 100644
--- a/mindnlp/core/types.py
+++ b/mindnlp/core/types.py
@@ -6,11 +6,18 @@
     int as _int,
     str as _str,
 )
+import mindspore
 
 from typing import Any, IO, TYPE_CHECKING, Union, Dict
 from typing_extensions import Self, TypeAlias
 
 from ._dtype import dtype
 
+DEVICE_MAP = {
+    'GPU': 'cuda',
+    'Ascend': 'npu',
+    'CPU': 'cpu'
+}
+
 class device():
     def __init__(self, type=None, index=None):
         if type is not None:
@@ -29,6 +36,9 @@ def __init__(self, type=None, index=None):
                 raise ValueError("core.device(): When input is core.device, `index` can not be set.")
             _target = type.type
             _id = type.index
+        elif isinstance(type, int):
+            _id = type
+            _target = DEVICE_MAP[mindspore.get_current_device().device_target]
         else:
             print(type)
             raise TypeError("core.device(): `type` must be type of 'str' or 'core.device'.")
diff --git a/mindnlp/utils/torch_proxy.py b/mindnlp/utils/torch_proxy.py
index 8c0c6ec32..81950036e 100644
--- a/mindnlp/utils/torch_proxy.py
+++ b/mindnlp/utils/torch_proxy.py
@@ -6,6 +6,8 @@
 import importlib.machinery
 from types import ModuleType
 
+TORCH_VERSION = '2.7.1+dev'
+
 class RedirectFinder(importlib.abc.MetaPathFinder):
     def __init__(self, redirect_map):
         # Redirect rules: proxied module -> actual module
@@ -97,7 +99,7 @@ def __setattr__(_, name, value):
 def initialize_torch_proxy():
     sys.meta_path.insert(0, RedirectFinder(REDIRECT_MAP))
     import torch
-    torch.__version__ = "2.1.1+dev"
+    torch.__version__ = TORCH_VERSION
 
 
 def setup_metadata_patch():
@@ -110,10 +112,10 @@ def setup_metadata_patch():
     def patched_distribution(dist_name):
         if dist_name == "torch":
             return types.SimpleNamespace(
-                version="2.1.1+dev",
-                metadata={"Name": "torch", "Version": "2.1.1+dev"},
+                version=TORCH_VERSION,
+                metadata={"Name": "torch", "Version": TORCH_VERSION},
                 read_text=lambda f: (
-                    f"Name: torch\nVersion: 2.1.1+dev" if f == "METADATA" else None
+                    f"Name: torch\nVersion: {TORCH_VERSION}" if f == "METADATA" else None
                 ),
             )
         return orig_distribution(dist_name)
@@ -124,8 +126,8 @@ def patched_distributions(**kwargs):
         dists.append(
             types.SimpleNamespace(
                 name="torch",
-                version="2.1.1+dev",
-                metadata={"Name": "torch", "Version": "2.1.1+dev"},
+                version=TORCH_VERSION,
+                metadata={"Name": "torch", "Version": TORCH_VERSION},
                 files=[],
                 locate_file=lambda p: None,
                 _normalized_name="torch",
diff --git a/setup.py b/setup.py
index 7bbfec1f4..a985f4b64 100644
--- a/setup.py
+++ b/setup.py
@@ -164,6 +164,7 @@ def run(self):
         'evaluate', # hf dependency
         'tokenizers', # hf dependency
         'safetensors', # hf dependency
+        'diffusers', # hf dependency
         'sentencepiece',
         'regex',
         'addict',
diff --git a/tests/run_test.py b/tests/run_test.py
index beef5c039..f4b6e6d93 100644
--- a/tests/run_test.py
+++ b/tests/run_test.py
@@ -1,4 +1,7 @@
+import os
 import sys
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
 import pytest
 import mindspore
 import mindnlp
@@ -23,6 +26,7 @@ def run_tests():
         "and not gradient_checkpointing " \
         "and not retain_grad " \
        "and not data_parallel " \
+        "and not model_parallelism " \
         "and not with_static_cache " \
         "and not compile " \
         "and not compilation " \
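
A minimal NumPy sketch of the bincount fallback added in mindnlp/core/ops/other.py above, for reference: counts are formed by broadcasting the input against a column vector of candidate bin indices and summing the resulting one-hot matrix along axis 1. `bincount_sketch` is an illustrative name, not a mindnlp API; the mindnlp version uses core.arange/core.eq/core.sum on tensors instead of NumPy arrays.

import numpy as np

def bincount_sketch(x, weights=None, minlength=0):
    # Number of bins: max value + 1, but at least minlength.
    length = max(int(x.max()) + 1, minlength)
    idx = np.arange(length)[:, None]   # shape (length, 1)
    one_hot = (x == idx)               # shape (length, n); True where x falls in that bin
    if weights is not None:
        if x.shape != weights.shape:
            raise ValueError('`input` and `weights` must have the same length')
        one_hot = weights * one_hot    # weighted bin membership
    return one_hot.sum(axis=1).ravel()

x = np.array([1, 1, 3])
assert (bincount_sketch(x, minlength=6) == np.bincount(x, minlength=6)).all()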