diff --git a/mindtorch/_apis/npu.py b/mindtorch/_apis/npu.py index c647f380b..4473ded28 100644 --- a/mindtorch/_apis/npu.py +++ b/mindtorch/_apis/npu.py @@ -791,10 +791,10 @@ def select(condition, input, other): return pyboost.select_op(condition, input, other) return legacy.select(condition, input, other) -def mean(input, axis, keepdims, dtype): +def mean(input, dim, keepdim, dtype): if use_pyboost(): - return pyboost.mean_ext_op(input, axis, keepdims, dtype) - return legacy.reduce_mean(input, axis, keepdims) + return pyboost.mean_ext_op(input, dim, keepdim, dtype) + return legacy.reduce_mean(input, dim, keepdim) def index(input, index): if use_pyboost(): @@ -1552,9 +1552,29 @@ def one_hot(tensor, num_classes): return legacy.one_hot(tensor, num_classes, on_value, off_value, -1) def var(input, dim=None, correction=1, keepdim=False): - if use_pyboost(): + if use_pyboost() and not ON_ORANGE_PI: return pyboost.var_op(input, dim, correction, keepdim) - return legacy.var(input, dim, correction, keepdim) + if dim is None: + input_mean = mean(input, (), False, None) + else: + input_mean = mean(input, dim=dim, keepdim=True, dtype=None) + + # 计算与均值的平方差 + squared_diff = pow(sub(input, input_mean, 1), 2) + # 计算方差 + if dim is None: + variance = mean(squared_diff, (), False, None) + n = input.numel() # 总元素个数 + else: + variance = mean(squared_diff, dim=dim, keepdim=keepdim, dtype=None) + n = input.size(dim) # 指定维度的元素个数 + + # 无偏估计校正 + if correction and n > 1: + variance = mul(variance, (n / (n - 1))) + + return variance + def linspace(start, end, steps, dtype=None): if use_pyboost() and not ON_ORANGE_PI: