diff --git a/mindnlp/core/_tensor.py b/mindnlp/core/_tensor.py index 46dbf3760..c93b284ce 100644 --- a/mindnlp/core/_tensor.py +++ b/mindnlp/core/_tensor.py @@ -312,8 +312,10 @@ def __setitem__(self, slices, value): value = tensor(value, dtype=self.dtype) else: value = value.to(self.dtype) + if 1 in value.shape and self[slices].ndim != value.ndim: value = value.squeeze() + return origin_setitem(self, slices, value) Tensor.__setitem__ = __setitem__ @@ -658,6 +660,21 @@ def __contains__(self, item): Tensor.exponential_ = ops.inplace_exponential StubTensor.exponential_ = ops.inplace_exponential + Tensor.log_ = ops.inplace_log + StubTensor.log_ = ops.inplace_log + + Tensor.mul_ = ops.inplace_mul + StubTensor.mul_ = ops.inplace_mul + + Tensor.neg_ = ops.inplace_neg + StubTensor.neg_ = ops.inplace_neg + + Tensor.exp_ = ops.inplace_exp + StubTensor.exp_ = ops.inplace_exp + + Tensor.sub_ = ops.inplace_sub + StubTensor.sub_ = ops.inplace_sub + def _rebuild_from_type_v2(func, new_type, args, state): ret = func(*args) return ret \ No newline at end of file diff --git a/mindnlp/core/ops/inplace.py b/mindnlp/core/ops/inplace.py index bdfd14ae7..3449cdd0f 100644 --- a/mindnlp/core/ops/inplace.py +++ b/mindnlp/core/ops/inplace.py @@ -1,8 +1,10 @@ +import numbers import mindspore from mindspore import ops from mindspore.ops._primitive_cache import _get_cache_prim from mindspore.common.generator import default_generator -from mindspore.ops.auto_generate.gen_ops_prim import inplace_normal_op, inplace_scatter_value_op, inplace_scatter_src_reduce_op +from mindspore.ops.auto_generate.gen_ops_prim import inplace_normal_op, inplace_scatter_value_op, inplace_scatter_src_reduce_op, \ + inplace_scatter_src_op from mindnlp import core from ..configs import use_pyboost @@ -85,6 +87,8 @@ def inplace_add(input, other, alpha): return input def inplace_scatter(input, dim, index, src): + if not isinstance(src, numbers.Number): + return inplace_scatter_src_op(input, dim, index, src) return inplace_scatter_value_op(input, dim, index, src) def inplace_index_copy(input, dim, index, tensor): @@ -157,6 +161,26 @@ def inplace_exponential(tensor, lambd=1.0): return tensor +def inplace_log(self): + self.data = core.log(self) + return self + +def inplace_mul(self, other): + self.data = core.mul(self, other) + return self + +def inplace_neg(self): + self.data = core.neg(self) + return self + +def inplace_exp(self): + self.data = core.exp(self) + return self + +def inplace_sub(self, other): + self.data = core.sub(self, other) + return self + __all__ = [ 'inplace_copy', 'inplace_zero', @@ -173,5 +197,10 @@ def inplace_exponential(tensor, lambd=1.0): 'inplace_triu', 'inplace_round', 'inplace_scatter_reduce', - 'inplace_exponential' + 'inplace_exponential', + 'inplace_log', + 'inplace_mul', + 'inplace_neg', + 'inplace_exp', + 'inplace_sub' ] diff --git a/mindnlp/core/ops/other.py b/mindnlp/core/ops/other.py index 089b0bb81..b2d7f464d 100644 --- a/mindnlp/core/ops/other.py +++ b/mindnlp/core/ops/other.py @@ -178,12 +178,10 @@ def cumsum(input, dim=None, dtype=None, out=None, **kwargs): # diag has_diag = hasattr(mindspore.mint, "diag") - - def diag(input, diagonal=0): if use_pyboost() and has_diag: return mindspore.mint.diag(input, diagonal) - return ops.diag(input) + return mindspore.numpy.diag(input, diagonal) # diag_embed @@ -806,6 +804,8 @@ def searchsorted( sorter=None, ): if use_pyboost() and has_searchsorted: + if not isinstance(values, core.Tensor): + values = core.tensor(values) return call_ms_func( mindspore.mint.searchsorted, sorted_sequence, @@ -1030,12 +1030,26 @@ def unfold(input, dimension, size, step): return output +def cartesian_prod(*tensors): + """ + 手动实现 torch.cartesian_prod + :param tensors: 一个或多个一维张量 + :return: 笛卡尔积结果的二维张量 (每行一个组合) + """ + # 生成网格坐标 + grids = core.meshgrid(*tensors, indexing='ij') + + # 展平每个网格张量并堆叠 + return core.stack([g.reshape(-1) for g in grids], dim=1) + + __all__ = [ "bincount", "broadcast_shapes", "broadcast_tensors", "broadcast_to", "bucketize", + "cartesian_prod", "cdist", "clone", "contains", diff --git a/mindnlp/core/ops/pointwise.py b/mindnlp/core/ops/pointwise.py index 501826639..cd59b2743 100644 --- a/mindnlp/core/ops/pointwise.py +++ b/mindnlp/core/ops/pointwise.py @@ -1,4 +1,5 @@ """pointwise op""" + import mindspore from mindspore import ops from ..configs import use_pyboost, ON_A1 @@ -7,41 +8,57 @@ from mindnlp import core # abs -has_abs = hasattr(mindspore.mint, 'abs') +has_abs = hasattr(mindspore.mint, "abs") + + def abs(input, *, out=None): if use_pyboost() and has_abs: return call_ms_func(mindspore.mint.abs, input, out=out) return call_ms_func(ops.abs, input, out=out) + # absolute def absolute(input, *, out=None): return abs(input, out=out) + # acos -has_acos = hasattr(mindspore.mint, 'acos') +has_acos = hasattr(mindspore.mint, "acos") + + def acos(input, *, out=None): if use_pyboost() and has_acos: return call_ms_func(mindspore.mint.acos, input, out=out) return call_ms_func(ops.acos, input, out=out) + # arccos def arrcos(input, out=None): return acos(input, out=out) + # acosh -has_acosh = hasattr(mindspore.mint, 'acosh') +has_acosh = hasattr(mindspore.mint, "acosh") + + def acosh(input, *, out=None): if use_pyboost and has_acosh: return call_ms_func(mindspore.mint.acosh, input, out=out) return call_ms_func(ops.acosh, input, out=out) + # arccosh -has_arccosh = hasattr(mindspore.mint, 'arccosh') +has_arccosh = hasattr(mindspore.mint, "arccosh") + + def arccosh(input): return acosh(input) + # add -has_add = hasattr(mindspore.mint, 'add') +has_add = hasattr(mindspore.mint, "add") + + def add(input, other, *, alpha=1, out=None): if use_pyboost() and has_add: return call_ms_func(mindspore.mint.add, input, other, alpha=alpha, out=out) @@ -49,128 +66,182 @@ def add(input, other, *, alpha=1, out=None): other = mul(alpha, other) return call_ms_func(ops.add, input, other, out=out) + # addcdiv def addcdiv(input, tensor1, tensor2, *, value=1): return ops.addcdiv(input, tensor1, tensor2, value) + # addcmul def addcmul(input, tensor1, tensor2, *, value=1): return ops.addcmul(input, tensor1, tensor2, value) + # angle def angle(input): return ops.angle(input) + # asin -has_asin = hasattr(mindspore.mint, 'asin') +has_asin = hasattr(mindspore.mint, "asin") + + def asin(input, *, out=None): if use_pyboost and has_asin: return call_ms_func(mindspore.mint.asin, input, out=out) return call_ms_func(ops.asin, input, out=out) + # arcsin -has_arcsin = hasattr(mindspore.mint, 'arcsin') +has_arcsin = hasattr(mindspore.mint, "arcsin") + + def arcsin(input, *, out=None): return asin(input, out=out) + # asinh -has_asinh = hasattr(mindspore.mint, 'asinh') +has_asinh = hasattr(mindspore.mint, "asinh") + + def asinh(input, *, out=None): if use_pyboost and has_asinh: return call_ms_func(mindspore.mint.asinh, input, out=out) return call_ms_func(ops.asinh, input, out=out) + # arcsinh -has_arcsinh = hasattr(mindspore.mint, 'arcsinh') +has_arcsinh = hasattr(mindspore.mint, "arcsinh") + + def arcsinh(input, *, out=None): return asinh(input, out=out) + # atan -has_atan = hasattr(mindspore.mint, 'atan') +has_atan = hasattr(mindspore.mint, "atan") + + def atan(input, *, out=None): if use_pyboost and has_atan: return call_ms_func(mindspore.mint.atan, input, out=out) return call_ms_func(ops.atan, input, out=out) + # arctan -has_arctan = hasattr(mindspore.mint, 'arctan') +has_arctan = hasattr(mindspore.mint, "arctan") + + def arctan(input, *, out=None): return atan(input, out=out) + # atanh -has_atanh = hasattr(mindspore.mint, 'atanh') +has_atanh = hasattr(mindspore.mint, "atanh") + + def atanh(input, *, out=None): if use_pyboost and has_atanh: return call_ms_func(mindspore.mint.atanh, input, out=out) return call_ms_func(ops.atanh, input, out=out) + # arctanh -has_arctanh = hasattr(mindspore.mint, 'arctanh') +has_arctanh = hasattr(mindspore.mint, "arctanh") + + def arctanh(input, *, out=None): return atanh(input, out=out) + # atan2 -has_atan2 = hasattr(mindspore.mint, 'atan2') +has_atan2 = hasattr(mindspore.mint, "atan2") + + def atan2(input, other, *, out=None): if use_pyboost() and has_atan2: return call_ms_func(mindspore.mint.atan2, input, other, out=out) return call_ms_func(ops.atan2, input, other, out=out) + # arctan2 -has_arctan2 = hasattr(mindspore.mint, 'arctan2') +has_arctan2 = hasattr(mindspore.mint, "arctan2") + + def arctan2(input, other, out=None): return atan2(input, other, out=out) + # bitwise_not # bitwise_and -has_bitwise_and = hasattr(mindspore.mint, 'bitwise_and') +has_bitwise_and = hasattr(mindspore.mint, "bitwise_and") + + def bitwise_and(input, other, *, out=None): if use_pyboost() and has_bitwise_and: return call_ms_func(mindspore.mint.bitwise_and, input, other, out=out) return call_ms_func(ops.bitwise_and, input, other, out=out) + # bitwise_or -has_bitwise_or = hasattr(mindspore.mint, 'bitwise_or') +has_bitwise_or = hasattr(mindspore.mint, "bitwise_or") + + def bitwise_or(input, other, *, out=None): if use_pyboost() and has_bitwise_or: return call_ms_func(mindspore.mint.bitwise_or, input, other, out=out) return call_ms_func(ops.bitwise_or, input, other, out=out) + # bitwise_xor -has_bitwise_xor = hasattr(mindspore.mint, 'bitwise_xor') +has_bitwise_xor = hasattr(mindspore.mint, "bitwise_xor") + + def bitwise_xor(input, other, *, out=None): if use_pyboost() and has_bitwise_xor: return call_ms_func(mindspore.mint.bitwise_xor, input, other, out=out) return call_ms_func(ops.bitwise_xor, input, other, out=out) + # bitwise_left_shift def bitwise_left_shift(input, other): return ops.bitwise_left_shift(input, other) + # bitwise_right_shift def bitwise_right_shift(input, other): return ops.bitwise_right_shift(input, other) + # ceil -has_ceil = hasattr(mindspore.mint, 'ceil') +has_ceil = hasattr(mindspore.mint, "ceil") + + def ceil(input, *, out=None): if use_pyboost() and has_ceil: return call_ms_func(mindspore.mint.ceil, input, out=out) return call_ms_func(ops.ceil, input, out=out) + # clamp -has_clamp = hasattr(mindspore.mint, 'clamp') +has_clamp = hasattr(mindspore.mint, "clamp") + + def clamp(input, min=None, max=None, *, out=None): if use_pyboost() and has_clamp: return call_ms_func(mindspore.mint.clamp, input, min, max, out=out) return call_ms_func(ops.clamp, input, min, max, out=out) + # clip -has_clip = hasattr(mindspore.mint, 'clip') +has_clip = hasattr(mindspore.mint, "clip") + + def clip(input, min=None, max=None): return clamp(input, min, max) + # conj_physical @@ -178,55 +249,79 @@ def clip(input, min=None, max=None): # cos -has_cos = hasattr(mindspore.mint, 'cos') +has_cos = hasattr(mindspore.mint, "cos") + + def cos(input, *, out=None): if use_pyboost() and has_cos: return call_ms_func(mindspore.mint.cos, input, out=out) return call_ms_func(ops.cos, input, out=out) + # cosh -has_cosh = hasattr(mindspore.mint, 'cosh') +has_cosh = hasattr(mindspore.mint, "cosh") + + def cosh(input, *, out=None): if use_pyboost() and has_cosh: return call_ms_func(mindspore.mint.cosh, input, out=out) return call_ms_func(ops.cosh, input, out=out) + # deg2rad def deg2rad(input): return ops.deg2rad(input) + # div -has_div = hasattr(mindspore.mint, 'div') +has_div = hasattr(mindspore.mint, "div") + + def div(input, other, *, rounding_mode=None, out=None): if use_pyboost() and has_div: - return call_ms_func(mindspore.mint.div, input, other, rounding_mode=rounding_mode, out=out) + return call_ms_func( + mindspore.mint.div, input, other, rounding_mode=rounding_mode, out=out + ) return call_ms_func(ops.div, input, other, rounding_mode=rounding_mode, out=out) + # divide -has_divide = hasattr(mindspore.mint, 'divide') +has_divide = hasattr(mindspore.mint, "divide") + + def divide(input, other, rounding_mode=None): return div(input, other, rounding_mode=rounding_mode) + # digamma def digamma(input): return ops.digamma(input) + # erf -has_erf = hasattr(mindspore.mint, 'erf') +has_erf = hasattr(mindspore.mint, "erf") + + def erf(input, *, out=None): if use_pyboost() and has_erf: return call_ms_func(mindspore.mint.erf, input, out=out) return call_ms_func(ops.erf, input, out=out) + # erfc -has_erfc = hasattr(mindspore.mint, 'erfc') +has_erfc = hasattr(mindspore.mint, "erfc") + + def erfc(input, *, out=None): if use_pyboost() and has_erfc: return call_ms_func(mindspore.mint.erfc, input, out=out) return call_ms_func(ops.erfc, input, out=out) + # erfinv -has_erfinv = hasattr(mindspore.mint, 'erfinv') +has_erfinv = hasattr(mindspore.mint, "erfinv") + + def erfinv(input, *, out=None): if use_pyboost() and has_erfinv: return call_ms_func(mindspore.mint.erfinv, input, out=out) @@ -234,8 +329,10 @@ def erfinv(input, *, out=None): # exp -has_exp = hasattr(mindspore.mint, 'exp') -has_inplace_exp = hasattr(mindspore.Tensor, 'exp_') +has_exp = hasattr(mindspore.mint, "exp") +has_inplace_exp = hasattr(mindspore.Tensor, "exp_") + + def exp(input, out=None): if has_inplace_exp: return inplace_exp(input, out) @@ -250,6 +347,7 @@ def exp(input, out=None): else: return output + def inplace_exp(input, out=None): if out is None: if use_pyboost() and has_exp: @@ -264,20 +362,27 @@ def inplace_exp(input, out=None): out.copy_(input) return out.exp_() + # exp2 -has_exp2 = hasattr(mindspore.mint, 'exp2') +has_exp2 = hasattr(mindspore.mint, "exp2") + + def exp2(input): if use_pyboost() and has_exp2: return mindspore.mint.exp2(input) return pow(2, input) + # expm1 -has_expm1 = hasattr(mindspore.mint, 'expm1') +has_expm1 = hasattr(mindspore.mint, "expm1") + + def expm1(input, *, out=None): if use_pyboost() and has_expm1: return call_ms_func(mindspore.mint.expm1, input, out=out) return call_ms_func(ops.expm1, input, out=out) + # fake_quantize_per_channel_affine @@ -288,37 +393,50 @@ def expm1(input, *, out=None): # float_power -has_float_power = hasattr(mindspore.mint, 'float_power') +has_float_power = hasattr(mindspore.mint, "float_power") + + def float_power(input, exponent): if use_pyboost() and has_float_power: return mindspore.mint.float_power(input, exponent) return ops.float_power(input, exponent) + # floor -has_floor = hasattr(mindspore.mint, 'floor') +has_floor = hasattr(mindspore.mint, "floor") + + def floor(input, *, out=None): if use_pyboost() and has_floor: return call_ms_func(mindspore.mint.floor, input, out=out) return call_ms_func(ops.floor, input, out=out) + # floor_divide def floor_divide(input, other): return ops.floor_divide(input, other) + # fmod -has_fmod = hasattr(mindspore.mint, 'fmod') +has_fmod = hasattr(mindspore.mint, "fmod") + + def fmod(input, other): if use_pyboost() and has_fmod: return mindspore.mint.fmod(input, other) return ops.fmod(input, other) + # frac -has_frac = hasattr(mindspore.mint, 'frac') +has_frac = hasattr(mindspore.mint, "frac") + + def frac(input): if use_pyboost() and has_frac: return mindspore.mint.frac(input) return fmod(input, 1) + # frexp @@ -326,43 +444,57 @@ def frac(input): def imag(input): return ops.imag(input) + # ldexp # lerp -has_lerp = hasattr(mindspore.mint, 'lerp') +has_lerp = hasattr(mindspore.mint, "lerp") + + def lerp(input, end, weight): if use_pyboost() and has_lerp: return mindspore.mint.lerp(input, end, weight) return ops.lerp(input, end, weight) + # lgamma def lgamma(input): return ops.lgamma(input) + # log -has_log = hasattr(mindspore.mint, 'log') +has_log = hasattr(mindspore.mint, "log") + + def log(input, *, out=None): if use_pyboost() and has_log: return call_ms_func(mindspore.mint.log, input, out=out) return call_ms_func(ops.log, input, out=out) + # log10 # log1p -has_log1p = hasattr(mindspore.mint, 'log1p') +has_log1p = hasattr(mindspore.mint, "log1p") + + def log1p(input, *, out=None): if use_pyboost() and has_log1p: return call_ms_func(mindspore.mint.log1p, input, out=out) return call_ms_func(ops.log1p, input, out=out) + # log2 -has_log2 = hasattr(mindspore.mint, 'log2') +has_log2 = hasattr(mindspore.mint, "log2") + + def log2(input): if use_pyboost() and has_log2: return mindspore.mint.log2(input) return ops.log2(input) + # logaddexp @@ -370,71 +502,97 @@ def log2(input): # logical_and -has_logical_and = hasattr(mindspore.mint, 'logical_and') +has_logical_and = hasattr(mindspore.mint, "logical_and") + + def logical_and(input, other, *, out=None): if use_pyboost() and has_logical_and: return call_ms_func(mindspore.mint.logical_and, input, other, out=out) return call_ms_func(ops.logical_and, input, other, out=out) + # logical_not -has_logical_not = hasattr(mindspore.mint, 'logical_not') +has_logical_not = hasattr(mindspore.mint, "logical_not") + + def logical_not(input, *, out=None): if use_pyboost() and has_logical_not: return call_ms_func(mindspore.mint.logical_not, input, out=out) return call_ms_func(ops.logical_not, input, out=out) + # logical_or -has_logical_or = hasattr(mindspore.mint, 'logical_or') +has_logical_or = hasattr(mindspore.mint, "logical_or") + + def logical_or(input, other, *, out=None): if use_pyboost() and has_logical_or: return call_ms_func(mindspore.mint.logical_or, input, other, out=out) return call_ms_func(ops.logical_or, input, other, out=out) + # logical_xor -has_logical_xor = hasattr(mindspore.mint, 'logical_xor') +has_logical_xor = hasattr(mindspore.mint, "logical_xor") + + def logical_xor(input, other, *, out=None): if use_pyboost() and has_logical_xor: return call_ms_func(mindspore.mint.logical_xor, input, other, out=out) return call_ms_func(ops.logical_xor, input, other, out=out) + # logit def logit(input, eps=None): return ops.logit(input, eps) + # hypot def hypot(input, other): return ops.hypot(input, other) + # i0 + # igamma def igamma(input, other): return ops.igamma(input, other) + # igammac def igammac(input, other): return ops.igammac(input, other) + # mul -has_mul = hasattr(mindspore.mint, 'mul') +has_mul = hasattr(mindspore.mint, "mul") + + def mul(input, other, *, out=None): if use_pyboost() and has_mul: return call_ms_func(mindspore.mint.mul, input, other, out=out) return call_ms_func(ops.mul, input, other, out=out) + # multiply def multiply(input, other): return mul(input, other) + # mvlgamma def mvlgamma(input, p): return ops.mvlgamma(input, p) + # nan_to_num -has_nan_to_num = hasattr(mindspore.mint, 'nan_to_num') +has_nan_to_num = hasattr(mindspore.mint, "nan_to_num") + + def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): if use_pyboost() and has_nan_to_num and not ON_A1: - return call_ms_func(mindspore.mint.nan_to_num, input, nan, posinf, neginf, out=out) + return call_ms_func( + mindspore.mint.nan_to_num, input, nan, posinf, neginf, out=out + ) # 创建输入张量的副本 output = input.clone() @@ -450,53 +608,77 @@ def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # 对于整数类型,使用给定值或默认值 default_posinf = core.iinfo(dtype).max if posinf is None else posinf default_neginf = core.iinfo(dtype).min if neginf is None else neginf - + # 替换 NaN if core.isnan(output).any(): - output = core.where(core.isnan(output), core.tensor(nan, dtype=output.dtype, device=output.device), output) - + output = core.where( + core.isnan(output), + core.tensor(nan, dtype=output.dtype, device=output.device), + output, + ) + # 替换正无穷大 if core.isinf(output).any() and (posinf is not None or output.is_floating_point()): - output = core.where((output == float('inf')) & core.isinf(output), core.tensor(default_posinf, dtype=output.dtype, device=output.device), output) - + output = core.where( + (output == float("inf")) & core.isinf(output), + core.tensor(default_posinf, dtype=output.dtype, device=output.device), + output, + ) + # 替换负无穷大 if core.isinf(output).any() and (neginf is not None or output.is_floating_point()): - output = core.where((output == float('-inf')) & core.isinf(output), - core.tensor(default_neginf, dtype=output.dtype, device=output.device), output) - + output = core.where( + (output == float("-inf")) & core.isinf(output), + core.tensor(default_neginf, dtype=output.dtype, device=output.device), + output, + ) + return output + # neg -has_neg = hasattr(mindspore.mint, 'neg') +has_neg = hasattr(mindspore.mint, "neg") + + def neg(input, *, out=None): if use_pyboost() and has_neg: return call_ms_func(mindspore.mint.neg, input, out=out) return call_ms_func(ops.neg, input, out=out) + # negative -has_negative = hasattr(mindspore.mint, 'negative') +has_negative = hasattr(mindspore.mint, "negative") + + def negative(input): return neg(input) + # nextafter def nextafter(input, other): return ops.nextafter(input, other) + # polygamma def polygamma(n, input): return ops.polygamma(n, input) + # positive def positive(input): return input + # pow -has_pow = hasattr(mindspore.mint, 'pow') +has_pow = hasattr(mindspore.mint, "pow") + + def pow(input, exponent, *, out=None): if use_pyboost() and has_pow: return call_ms_func(mindspore.mint.pow, input, exponent, out=out) return call_ms_func(ops.pow, input, exponent, out=out) + # quantized_batch_norm @@ -510,144 +692,295 @@ def pow(input, exponent, *, out=None): def rad2deg(input): return ops.rad2deg(input) + # real def real(input): return ops.real(input) + # reciprocal -has_reciprocal = hasattr(mindspore.mint, 'reciprocal') +has_reciprocal = hasattr(mindspore.mint, "reciprocal") + + def reciprocal(input, *, out=None): if use_pyboost() and has_reciprocal: return call_ms_func(mindspore.mint.reciprocal, input, out=out) return call_ms_func(ops.reciprocal, input, out=out) + # remainder -has_remainder = hasattr(mindspore.mint, 'remainder') +has_remainder = hasattr(mindspore.mint, "remainder") + + def remainder(input, other, *, out=None): if use_pyboost() and has_remainder: return call_ms_func(mindspore.mint.remainder, input, other, out=out) return call_ms_func(ops.remainder, input, other, out=out) + # round -has_round = hasattr(mindspore.mint, 'round') +has_round = hasattr(mindspore.mint, "round") + + def round(input, *, decimals=0): if use_pyboost() and has_round: return mindspore.mint.round(input, decimals=decimals) return ops.round(input, decimals=decimals) + # rsqrt -has_rsqrt = hasattr(mindspore.mint, 'rsqrt') +has_rsqrt = hasattr(mindspore.mint, "rsqrt") + + def rsqrt(input, *, out=None): if use_pyboost() and has_rsqrt: return call_ms_func(mindspore.mint.rsqrt, input, out=out) return call_ms_func(ops.rsqrt, input, out=out) + # sigmoid -has_sigmoid = hasattr(mindspore.mint, 'sigmoid') +has_sigmoid = hasattr(mindspore.mint, "sigmoid") + + def sigmoid(input, *, out=None): if use_pyboost() and has_sigmoid: return call_ms_func(mindspore.mint.sigmoid, input, out=out) return call_ms_func(ops.sigmoid, input, out=out) + # sign -has_sign = hasattr(mindspore.mint, 'sign') +has_sign = hasattr(mindspore.mint, "sign") + + def sign(input, *, out=None): if use_pyboost() and has_sign: return call_ms_func(mindspore.mint.sign, input, out=out) return call_ms_func(ops.sign, input, out=out) + # sgn # signbit # sin -has_sin = hasattr(mindspore.mint, 'sin') +has_sin = hasattr(mindspore.mint, "sin") + + def sin(input, *, out=None): if use_pyboost() and has_sin: return call_ms_func(mindspore.mint.sin, input, out=out) return call_ms_func(ops.sin, input, out=out) + # sinc -has_sinc = hasattr(mindspore.mint, 'sinc') +has_sinc = hasattr(mindspore.mint, "sinc") + + def sinc(input, *, out=None): if use_pyboost() and has_sinc: return call_ms_func(mindspore.mint.sinc, input, out=out) return call_ms_func(ops.sinc, input, out=out) + # sinh -has_sinh = hasattr(mindspore.mint, 'sinh') +has_sinh = hasattr(mindspore.mint, "sinh") + + def sinh(input, *, out=None): if use_pyboost() and has_sinh: return call_ms_func(mindspore.mint.sinh, input, out=out) return call_ms_func(ops.sinh, input, out=out) + # softmax def softmax(input, dim, *, dtype=None): if use_pyboost(): return mindspore.mint.nn.functional.softmax(input, dim, dtype=dtype) return ops.softmax(input, dim, dtype=dtype) + +def log_softmax(input, dim=None, dtype=None): + return core.nn.functional.log_softmax(input, dim, dtype) + + # sqrt -has_sqrt = hasattr(mindspore.mint, 'sqrt') +has_sqrt = hasattr(mindspore.mint, "sqrt") + + def sqrt(input, *, out=None): if use_pyboost() and has_sqrt: return call_ms_func(mindspore.mint.sqrt, input, out=out) return call_ms_func(ops.sqrt, input, out=out) + # square -has_square = hasattr(mindspore.mint, 'square') +has_square = hasattr(mindspore.mint, "square") + + def square(input, *, out=None): if use_pyboost() and has_square: return call_ms_func(mindspore.mint.square, input, out=out) return call_ms_func(ops.square, input, out=out) + # sub -has_sub = hasattr(mindspore.mint, 'sub') +has_sub = hasattr(mindspore.mint, "sub") + + def sub(input, other, *, alpha=1, out=None): if use_pyboost() and has_sub: return call_ms_func(mindspore.mint.sub, input, other, alpha=alpha, out=out) return call_ms_func(ops.sub, input, other, out=out) + # subtract def subtract(input, other): return sub(input, other) + # tan -has_tan = hasattr(mindspore.mint, 'tan') +has_tan = hasattr(mindspore.mint, "tan") + + def tan(input, *, out=None): if use_pyboost() and has_tan: return call_ms_func(mindspore.mint.tan, input, out=out) return call_ms_func(ops.tan, input, out=out) + # tanh -has_tanh = hasattr(mindspore.mint, 'tanh') +has_tanh = hasattr(mindspore.mint, "tanh") + + def tanh(input, *, out=None): if use_pyboost() and has_tanh: return call_ms_func(mindspore.mint.tanh, input, out=out) return call_ms_func(ops.tanh, input, out=out) + # true_divide def true_divide(input, other): return div(input, other) + # trunc -has_trunc = hasattr(mindspore.mint, 'trunc') +has_trunc = hasattr(mindspore.mint, "trunc") + + def trunc(input, *, out=None): if use_pyboost() and has_trunc: return call_ms_func(mindspore.mint.trunc, input, out=out) return call_ms_func(ops.trunc, input, out=out) + # xlogy -has_xlogy = hasattr(mindspore.mint, 'xlogy') +has_xlogy = hasattr(mindspore.mint, "xlogy") + + def xlogy(input, other, *, out=None): if use_pyboost() and has_xlogy: return call_ms_func(mindspore.mint.xlogy, input, other, out=out) return call_ms_func(ops.xlogy, input, other, out=out) + # relu def relu(input): if use_pyboost(): return mindspore.mint.nn.functional.relu(input) return ops.relu(input) -__all__ = ['abs', 'absolute', 'acos', 'acosh', 'add', 'addcdiv', 'addcmul', 'angle', 'arccosh', 'arcsin', 'arcsinh', 'arctan', 'arctan2', 'arctanh', 'arrcos', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'ceil', 'clamp', 'clip', 'cos', 'cosh', 'deg2rad', 'digamma', 'div', 'divide', 'erf', 'erfc', 'erfinv', 'exp', 'exp2', 'expm1', 'float_power', 'floor', 'floor_divide', 'fmod', 'frac', 'hypot', 'igamma', 'igammac', 'imag', 'lerp', 'lgamma', 'log', 'log1p', 'log2', 'logical_and', 'logical_not', 'logical_or', 'logical_xor', 'logit', 'mul', 'multiply', 'mvlgamma', 'nan_to_num', 'neg', 'negative', 'nextafter', 'polygamma', 'positive', 'pow', 'rad2deg', 'real', 'reciprocal', 'remainder', 'round', 'rsqrt', 'sigmoid', 'sign', 'sin', 'sinc', 'sinh', 'softmax', 'sqrt', 'square', 'sub', 'subtract', 'tan', 'tanh', 'true_divide', 'trunc', 'xlogy', 'relu'] \ No newline at end of file + +__all__ = [ + "abs", + "absolute", + "acos", + "acosh", + "add", + "addcdiv", + "addcmul", + "angle", + "arccosh", + "arcsin", + "arcsinh", + "arctan", + "arctan2", + "arctanh", + "arrcos", + "asin", + "asinh", + "atan", + "atan2", + "atanh", + "bitwise_and", + "bitwise_left_shift", + "bitwise_or", + "bitwise_right_shift", + "bitwise_xor", + "ceil", + "clamp", + "clip", + "cos", + "cosh", + "deg2rad", + "digamma", + "div", + "divide", + "erf", + "erfc", + "erfinv", + "exp", + "exp2", + "expm1", + "float_power", + "floor", + "floor_divide", + "fmod", + "frac", + "hypot", + "igamma", + "igammac", + "imag", + "lerp", + "lgamma", + "log", + "log1p", + "log2", + "logical_and", + "logical_not", + "logical_or", + "logical_xor", + "logit", + "log_softmax", + "mul", + "multiply", + "mvlgamma", + "nan_to_num", + "neg", + "negative", + "nextafter", + "polygamma", + "positive", + "pow", + "rad2deg", + "real", + "reciprocal", + "remainder", + "round", + "rsqrt", + "sigmoid", + "sign", + "sin", + "sinc", + "sinh", + "softmax", + "sqrt", + "square", + "sub", + "subtract", + "tan", + "tanh", + "true_divide", + "trunc", + "xlogy", + "relu", +]