From 3c2a468c6eba4b2d191c2a48451d7adcc75caa98 Mon Sep 17 00:00:00 2001
From: lvyufeng
Date: Mon, 21 Jul 2025 02:14:38 +0800
Subject: [PATCH] fix e class ut

---
 mindnlp/core/__init__.py               |  6 ++++
 mindnlp/core/_tensor.py                | 46 +++++++++++++++++++++-----
 mindnlp/core/cuda/amp/autocast_mode.py | 11 ------
 mindnlp/core/linalg/__init__.py        |  2 +-
 mindnlp/core/ops/inplace.py            | 17 +++++++++-
 mindnlp/core/ops/reduction.py          |  5 ++-
 6 files changed, 65 insertions(+), 22 deletions(-)

diff --git a/mindnlp/core/__init__.py b/mindnlp/core/__init__.py
index cf140038b..46284f102 100644
--- a/mindnlp/core/__init__.py
+++ b/mindnlp/core/__init__.py
@@ -96,4 +96,10 @@ def set_autocast_dtype(device_type, dtype):
 def get_autocast_dtype(device_type):
     return AUTO_CAST_DTYE[device_type]
 
+def get_autocast_gpu_dtype():
+    return AUTO_CAST_DTYE['cuda']
+
+def is_autocast_enabled():
+    return True
+
 __version__ = 'test_version_no_value'
\ No newline at end of file
diff --git a/mindnlp/core/_tensor.py b/mindnlp/core/_tensor.py
index 27c62eac8..505867497 100644
--- a/mindnlp/core/_tensor.py
+++ b/mindnlp/core/_tensor.py
@@ -283,13 +283,13 @@ def __setitem__(self, slices, value):
         value = ops.finfo(self.dtype).max
     elif value == -float('inf'):
         value = ops.finfo(self.dtype).min
-    # if isinstance(slices, tuple):
-    #     new_slices = ()
-    #     for s in slices:
-    #         if isinstance(s, range):
-    #             s = list(s)
-    #         new_slices += (s,)
-    #     slices = new_slices
+    if isinstance(slices, tuple):
+        new_slices = ()
+        for s in slices:
+            if isinstance(s, range):
+                s = list(s)
+            new_slices += (s,)
+        slices = new_slices
     if not isinstance(value, Tensor):
         value = tensor(value, dtype=self.dtype)
     return origin_setitem(self, slices, value)
@@ -507,10 +507,40 @@ def __repr__(self):
 
     Tensor.__repr__ = __repr__
     StubTensor.__repr__ = _stub_method(__repr__)
 
-
     def detach_(self):
         return ops.stop_gradient(self)
+    Tensor.detach_ = detach_
+    StubTensor.detach_ = detach_
+
+    def new_full(self, size, fill_value, *, dtype=None, device=None, requires_grad=False, layout=None, pin_memory=False):
+        return ops.full(size, fill_value, dtype=dtype if dtype is not None else self.dtype)
+
+    Tensor.new_full = new_full
+    StubTensor.new_full = new_full
+
+    def new_zeros(self, *size, dtype=None, device=None, requires_grad=False, layout=None, pin_memory=False):
+        return ops.zeros(*size, dtype=dtype if dtype is not None else self.dtype)
+
+    Tensor.new_zeros = new_zeros
+    StubTensor.new_zeros = new_zeros
+
+    Tensor.sum = ops.sum
+    StubTensor.sum = ops.sum
+
+    def new_tensor(self, data, *, dtype=None, device=None, requires_grad=False, layout=None, pin_memory=False):
+        return tensor(data, dtype=dtype if dtype is not None else self.dtype)
+
+    Tensor.new_tensor = new_tensor
+    StubTensor.new_tensor = new_tensor
+
+    Tensor.fill_diagonal_ = ops.inplace_fill_diagonal
+    StubTensor.fill_diagonal_ = ops.inplace_fill_diagonal
+
+    Tensor.triu_ = ops.inplace_triu
+    StubTensor.triu_ = ops.inplace_triu
+
+
 def _rebuild_from_type_v2(func, new_type, args, state):
     ret = func(*args)
     return ret
\ No newline at end of file
diff --git a/mindnlp/core/cuda/amp/autocast_mode.py b/mindnlp/core/cuda/amp/autocast_mode.py
index cd22ce019..4b344708d 100644
--- a/mindnlp/core/cuda/amp/autocast_mode.py
+++ b/mindnlp/core/cuda/amp/autocast_mode.py
@@ -26,29 +26,18 @@ def __init__(
         self,
         enabled: bool = True,
         dtype: core.dtype = core.float16,
         cache_enabled: bool = True,
     ):
-        if core._jit_internal.is_scripting():
-            self._enabled = enabled
-            self.device = "cuda"
-            self.fast_dtype = dtype
-            return
         super().__init__(
             "cuda", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled
         )
 
     def __enter__(self):
-        if core._jit_internal.is_scripting():
-            return self
         return super().__enter__()
 
     # TODO: discuss a unified TorchScript-friendly API for autocast
     def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):  # type: ignore[override]
-        if core._jit_internal.is_scripting():
-            return
         return super().__exit__(exc_type, exc_val, exc_tb)
 
     def __call__(self, func):
-        if core._jit_internal.is_scripting():
-            return func
         return super().__call__(func)
diff --git a/mindnlp/core/linalg/__init__.py b/mindnlp/core/linalg/__init__.py
index 5b56215c4..564171c14 100644
--- a/mindnlp/core/linalg/__init__.py
+++ b/mindnlp/core/linalg/__init__.py
@@ -22,4 +22,4 @@ def cholesky_ex(A, *, upper=False, check_errors=False, out=None):
 
 
 def norm(A, ord=None, dim=None, keepdim=False, *, out=None, dtype=None):
-    return mint.norm(A, ord, dim, keepdim, dtype=dtype)
+    return mint.norm(A, 2 if ord is None else ord, dim, keepdim, dtype=dtype)
diff --git a/mindnlp/core/ops/inplace.py b/mindnlp/core/ops/inplace.py
index 0dc00ca60..544d83b79 100644
--- a/mindnlp/core/ops/inplace.py
+++ b/mindnlp/core/ops/inplace.py
@@ -118,6 +118,19 @@ def inplace_unsqueeze(input, dim=None):
     input.assign_value(out)
     return input
 
+def inplace_fill_diagonal(input, fill_value, wrap=False):
+    fill_diagonal_ = _get_cache_prim(ops.FillDiagonal)(float(fill_value), wrap)
+    out = fill_diagonal_(input)
+    input.assign_value(out)
+    return input
+
+def inplace_triu(input, diagonal=0):
+    out = ops.triu(input, diagonal)
+    input.assign_value(out)
+    return input
+
+
+
 __all__ = [
     'inplace_copy',
     'inplace_zero',
@@ -129,5 +142,7 @@ def inplace_unsqueeze(input, dim=None):
     'inplace_index_copy',
     'inplace_index_add',
     'inplace_squeeze',
-    'inplace_unsqueeze'
+    'inplace_unsqueeze',
+    'inplace_fill_diagonal',
+    'inplace_triu'
 ]
diff --git a/mindnlp/core/ops/reduction.py b/mindnlp/core/ops/reduction.py
index 808d62b6f..94d73f657 100644
--- a/mindnlp/core/ops/reduction.py
+++ b/mindnlp/core/ops/reduction.py
@@ -181,7 +181,10 @@ def std_mean(input, dim=None, *, correction=1, keepdim=False):
 
 # sum
 has_sum = hasattr(mindspore.mint, 'sum')
-def sum(input, dim=None, keepdim=False, *, dtype=None):
+def sum(input, dim=None, keepdim=False, *, dtype=None, **kwargs):
+    keepdims = kwargs.pop('keepdims', None)
+    if keepdims is not None:
+        keepdim = keepdims
     if 0 in input.shape:
         return mindspore.tensor(0, dtype=dtype)
     if use_pyboost() and has_sum:
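
Reviewer note, not part of the patch itself: the sketch below is a minimal smoke test of the surface this diff touches. It is an assumption-laden illustration, not shipped code: it presumes mindspore and this branch of mindnlp are importable, that `tensor` and `linalg` are reachable from `mindnlp.core` in the layout the diffs above imply, and that a default CPU context is enough to run the ops. Sample values are illustrative only.

# hypothetical smoke test for the patched APIs (assumed import layout)
from mindnlp import core
from mindnlp.core import tensor

x = tensor([[1.0, 2.0], [3.0, 4.0]])

# new_* factories added in _tensor.py default to the source tensor's dtype
assert x.new_full((2, 2), 7.0).dtype == x.dtype
assert x.new_zeros(2, 3).shape == (2, 3)
assert x.new_tensor([5.0, 6.0]).dtype == x.dtype

# Tensor.sum is now bound to ops.sum, which also tolerates the
# NumPy-style `keepdims` alias added in reduction.py
assert x.sum(dim=1, keepdims=True).shape == (2, 1)

# in-place helpers from ops/inplace.py, bound as Tensor methods
x.fill_diagonal_(0.0)   # zero the main diagonal in place
x.triu_()               # keep only the upper triangle in place

# linalg.norm now substitutes ord=2 instead of forwarding None to mint.norm
n = core.linalg.norm(tensor([3.0, 4.0]))   # expect a value close to 5.0

# autocast helpers added to core/__init__.py
assert core.is_autocast_enabled()
print(core.get_autocast_gpu_dtype())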