From fb993e06dc4e71b413dea89dc02da0df632f2873 Mon Sep 17 00:00:00 2001 From: lvyufeng Date: Mon, 21 Jul 2025 12:34:29 +0800 Subject: [PATCH] fix e class --- mindnlp/core/_tensor.py | 19 ++++++++++++++++++- mindnlp/core/nn/functional.py | 1 - 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/mindnlp/core/_tensor.py b/mindnlp/core/_tensor.py index 505867497..a654c3d3e 100644 --- a/mindnlp/core/_tensor.py +++ b/mindnlp/core/_tensor.py @@ -1,6 +1,6 @@ import math import numpy as np -from functools import partial +import warnings import mindspore from mindspore import Tensor from mindspore.common.tensor import _TensorMeta @@ -379,6 +379,8 @@ def unfold(self, dimension, size, step): StubTensor.unfold = unfold def new(self, *shape): + if not isinstance(shape[0], int): + return tensor(shape[0], dtype=self.dtype) return ops.empty(*shape, dtype=self.dtype) Tensor.new = new @@ -540,6 +542,21 @@ def new_tensor(self, data, *, dtype=None, device=None, requires_grad=False, layo Tensor.triu_ = ops.inplace_triu StubTensor.triu_ = ops.inplace_triu + @property + def real(self): + return ops.real(self) + + Tensor.real = real + StubTensor.real = real + + def bfloat16(self): + if ON_A1: + warnings.warn('910A do not support bfloat16, use float16 instead.') + return self.to(_dtype.float16) + return self.to(_dtype.bfloat16) + + Tensor.bfloat16 = bfloat16 + StubTensor.bfloat16 = bfloat16 def _rebuild_from_type_v2(func, new_type, args, state): ret = func(*args) diff --git a/mindnlp/core/nn/functional.py b/mindnlp/core/nn/functional.py index fa66b06af..497fca22c 100644 --- a/mindnlp/core/nn/functional.py +++ b/mindnlp/core/nn/functional.py @@ -166,7 +166,6 @@ def avg_pool2d(input, kernel_size, stride=None, padding=0, ceil_mode=False, coun Returns: - numpy array: The result of the average pooling operation. """ - print(kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) if use_pyboost(): return mint.nn.functional.avg_pool2d(input, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)