From 50ff9685f6e5b128ce0cbf94334636a9c78d15d0 Mon Sep 17 00:00:00 2001
From: lvyufeng
Date: Mon, 22 Sep 2025 08:00:51 +0000
Subject: [PATCH] fix torch.__version__

---
 mindtorch/__init__.py      |   4 +-
 mindtorch/_apis/gpu.py     |  30 ++++++++
 mindtorch/_apis/npu.py     |   5 +-
 mindtorch/_tensor.py       |   2 +-
 mindtorch/nn/functional.py | 140 +++++++++++++++++++++++++++++++++++--
 mindtorch/ops/array.py     |   5 ++
 mindtorch/ops/pointwise.py |   1 +
 7 files changed, 178 insertions(+), 9 deletions(-)

diff --git a/mindtorch/__init__.py b/mindtorch/__init__.py
index 0b60c2e68..827b4b6ec 100644
--- a/mindtorch/__init__.py
+++ b/mindtorch/__init__.py
@@ -178,9 +178,9 @@ def _running_with_deploy():
 from ._lowrank import svd_lowrank
 from .random import get_rng_state, initial_seed, manual_seed, seed, set_rng_state
 
+__version__ = 'test_version_no_value'
+
 from .torch_proxy import initialize_torch_proxy, setup_metadata_patch
 initialize_torch_proxy()
 setup_metadata_patch()
-
-__version__ = 'test_version_no_value'
\ No newline at end of file
diff --git a/mindtorch/_apis/gpu.py b/mindtorch/_apis/gpu.py
index b59c8db14..7e2d63a16 100644
--- a/mindtorch/_apis/gpu.py
+++ b/mindtorch/_apis/gpu.py
@@ -265,6 +265,8 @@ def less(input, other):
     return legacy.less(input, other)
 
 def select(condition, x, y):
+    if 0 in condition.shape:
+        return mindspore.Tensor(Tensor_(shape=condition.shape, dtype=x.dtype))
     if isinstance(x, numbers.Number) or x.ndim == 0:
         x = fill_scalar(condition.shape, x, None)
     if isinstance(y, numbers.Number) or y.ndim == 0:
@@ -1150,3 +1152,31 @@ def einsum(equation, operands):
 def unique2(input, sorted, return_inverse, return_counts):
     outs = legacy.unique(input)
     return outs + (None,)
+
+def logaddexp(input, other):
+    m = maximum(input, other)
+    abs_val = abs(sub(input, other))
+    exp_val = exp(neg(abs_val))
+    y = add(m, log1p(exp_val))
+    return y
+
+def kl_div(input, target, reduction, log_target):
+    if log_target:
+        target = log(target)
+
+    if reduction == 'batchmean':
+        kl_div_sum = legacy.kl_div_loss(input, target, 'sum')
+        # shape = input.shape
+        # batch_size = shape[0]
+        # return div(kl_div_sum, batch_size)
+        return kl_div_sum
+
+    if reduction == 'mean':
+        kl_div_sum = legacy.kl_div_loss(input, target, 'sum')
+        shape = input.shape
+        total_size = 1
+        for dim in shape:
+            total_size = total_size * dim
+        return div(kl_div_sum, total_size)
+
+    return legacy.kl_div_loss(input, target, reduction)
diff --git a/mindtorch/_apis/npu.py b/mindtorch/_apis/npu.py
index a4f8fcdab..83bd25115 100644
--- a/mindtorch/_apis/npu.py
+++ b/mindtorch/_apis/npu.py
@@ -1612,4 +1612,7 @@ def new_empty(input, size, dtype):
     return pyboost.new_empty_op(input, size, dtype, 'Ascend')
 
 def new_ones(input, size, dtype):
-    return pyboost.new_ones_op(input, size, dtype)
\ No newline at end of file
+    return pyboost.new_ones_op(input, size, dtype)
+
+def kl_div(input, target, reduction, log_target):
+    return pyboost.kl_div_op(input, target, reduction, log_target)
diff --git a/mindtorch/_tensor.py b/mindtorch/_tensor.py
index ca862a168..e9f825d97 100644
--- a/mindtorch/_tensor.py
+++ b/mindtorch/_tensor.py
@@ -177,7 +177,7 @@ def cuda(self, device=None, non_blocking=False):
         if DEVICE_TARGET == 'Ascend':
             return self.npu(device, non_blocking)
         if device is None:
-            device = device_('gpu', 0)
+            device = device_('cuda', 0)
         return self.to(device, non_blocking=non_blocking)
 
     def requires_grad_(self, requires_grad=True):
diff --git a/mindtorch/nn/functional.py b/mindtorch/nn/functional.py
index ba128fd8a..ebc93fb7b 100644
--- a/mindtorch/nn/functional.py
+++ b/mindtorch/nn/functional.py
@@ -557,10 +557,16 @@ def smooth_l1_loss(input, target, beta=1.0, reduction='none'):
         target = target.to(mindtorch.float32)
     return ops.smooth_l1_loss(input, target, beta, reduction)
 
-def kl_div(logits, labels, reduction='mean', log_target=False):
-    if log_target:
-        labels = ops.log(labels)
-    return ops.kl_div(logits, labels, reduction)
+def kl_div(input, target, reduction='mean', log_target=False):
+    if reduction == 'batchmean':
+        reduced = execute('kl_div', input, target, 'sum', log_target)
+    else:
+        reduced = execute('kl_div', input, target, reduction, log_target)
+
+    if reduction == 'batchmean' and input.ndim != 0:
+        reduced = mindtorch.div(reduced, input.shape[0])
+
+    return reduced
 
 def softmax(input, dim=-1, *, dtype=None):
     if dtype is not None:
@@ -1602,8 +1608,132 @@ def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
 def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1):
     return execute('col2im', input, output_size, kernel_size, dilation, padding, stride)
 
+# def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean', zero_infinity=False):
+#     return execute('ctc_loss', log_probs, targets, input_lengths, target_lengths, blank, reduction, zero_infinity)
+
 def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean', zero_infinity=False):
-    return execute('ctc_loss', log_probs, targets, input_lengths, target_lengths, blank, reduction, zero_infinity)
+    """
+    Manual CTC loss implementation that vectorizes the forward recursion over the extended-target dimension.
+    Batched inputs are supported; the remaining Python loops run only over the batch and time dimensions.
+
+    Args:
+        log_probs: Tensor of size (T, N, C), where T = input length, N = batch size and C = number of classes (including the blank).
+            Expected to be the output of log_softmax.
+        targets: Tensor of size (N, S) or (sum(target_lengths)) holding the target sequences. Must not contain the blank symbol.
+        input_lengths: Tensor or tuple of size (N) with the actual length of each input sequence.
+        target_lengths: Tensor or tuple of size (N) with the actual length of each target sequence.
+        blank (int, optional): Class index of the blank symbol. Default: 0.
+        reduction (str, optional): Reduction to apply: 'none' | 'mean' | 'sum'. Default: 'mean'.
+        zero_infinity (bool, optional): Whether to zero out infinite losses (and their gradients). Default: False.
+
+    Returns:
+        Tensor: The computed CTC loss.
+    """
+    T, N, C = log_probs.size()
+    device = log_probs.device
+    dtype = log_probs.dtype
+
+    # Initialize the per-sample loss tensor
+    losses = mindtorch.zeros(N, device=device, dtype=dtype)
+
+    # Handle both targets layouts: (N, S) or a 1-D concatenation of length sum(target_lengths)
+    if targets.dim() == 1:
+        # targets is the 1-D concatenated form
+        targets_ = targets
+    else:
+        # targets is the 2-D (N, S) form
+        targets_ = targets.view(-1)
+
+    # Iterate over each sample in the batch
+    for n in range(N):
+        T_n = input_lengths[n]
+        S_n = target_lengths[n]
+
+        if S_n == 0:
+            # If the target length is 0, the loss is the negative sum of the blank log-probabilities over the valid steps
+            blank_log_probs = log_probs[:T_n, n, blank]
+            losses[n] = -mindtorch.sum(blank_log_probs)
+            if zero_infinity and mindtorch.isinf(losses[n]):
+                losses[n] = 0.0
+            continue
+
+        # Get the target sequence for the current sample
+        if targets.dim() == 1:
+            start_index = sum(target_lengths[:n])
+            end_index = start_index + S_n
+            target_seq = targets_[start_index:end_index]
+        else:
+            target_seq = targets[n, :S_n]
+
+        # Build the blank-interleaved extended target sequence (length L = 2 * S_n + 1)
+        extended_targets = mindtorch.zeros(2 * S_n + 1, device=device, dtype=mindtorch.long)
+        extended_targets[0] = blank
+        extended_targets[1::2] = target_seq
+        extended_targets[2::2] = blank
+        L = len(extended_targets)
+
+        # Initialize the forward variables alpha with shape (T_n, L)
+        alpha = mindtorch.full((T_n, L), mindtorch.finfo(dtype).min, device=device, dtype=dtype)
+
+        # Initialize the first time step
+        alpha[0, 0] = log_probs[0, n, extended_targets[0]]  # start with a blank
+        if L > 1:
+            alpha[0, 1] = log_probs[0, n, extended_targets[1]]  # start with the first real label
+
+        # Forward recursion over time to fill alpha
+        for t in range(1, T_n):
+            # Gather the log-probabilities of every extended-target symbol at this time step
+            # log_probs_t has shape (L,)
+            log_probs_t = log_probs[t, n, extended_targets]
+
+            # Alphas from the previous time step, used for the vectorized update
+            alpha_prev = alpha[t-1]
+
+            # Case 1: stay at position s
+            stay_log_prob = alpha_prev + log_probs_t
+
+            # Case 2: move from position s-1 (advance one step)
+            move_one_log_prob = mindtorch.empty_like(stay_log_prob)
+            move_one_log_prob[0] = -float('inf')  # position 0 has no s-1
+            move_one_log_prob[1:] = alpha_prev[:-1] + log_probs_t[1:]
+
+            # Case 3: skip from position s-2 (advance two steps, only when allowed)
+            move_two_log_prob = mindtorch.empty_like(stay_log_prob)
+            move_two_log_prob[:2] = -float('inf')  # the first two positions have no s-2
+            # Allowed when s >= 2, the current symbol is not blank, and it differs from the symbol at s-2
+            condition = (extended_targets[2:] != extended_targets[:-2]) & (extended_targets[2:] != blank)
+            # Apply the condition to alpha_prev[s-2]
+            eligible_s2 = alpha_prev[:-2].clone()
+            eligible_s2[~condition] = -float('inf')  # disallowed transitions get -inf
+            move_two_log_prob[2:] = eligible_s2 + log_probs_t[2:]
+
+            # Combine the three cases with logaddexp
+            # first combine stay and move_one
+            combined_stay_move = mindtorch.logaddexp(stay_log_prob, move_one_log_prob)
+            # then fold in move_two
+            alpha[t] = mindtorch.logaddexp(combined_stay_move, move_two_log_prob)
+
+        # Total log-likelihood: the last two positions at the final valid time step
+        if L > 1:
+            total_log_prob = mindtorch.logaddexp(alpha[T_n-1, L-1], alpha[T_n-1, L-2])
+        else:
+            total_log_prob = alpha[T_n-1, L-1]
+
+        losses[n] = -total_log_prob
+
+        # Handle zero_infinity
+        if zero_infinity and mindtorch.isinf(losses[n]):
+            losses[n] = 0.0
+
+    # Apply the requested reduction
+    if reduction == 'none':
+        return losses
+    elif reduction == 'sum':
+        return losses.sum()
+    elif reduction == 'mean':
+        return losses.mean()
+    else:
+        raise ValueError("reduction should be 'none', 'sum', or 'mean'.")
 
 def one_hot(tensor, num_classes=-1):
     return execute('one_hot', tensor, num_classes)
diff --git a/mindtorch/ops/array.py b/mindtorch/ops/array.py
index fd2edd247..88186d6af 100644
--- a/mindtorch/ops/array.py
+++ b/mindtorch/ops/array.py
@@ -825,6 +825,7 @@ def _slice_helper(tensor, slice_spec, do_update=False, updates=None):
         strides, begin_mask,
         end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask
     )
+    tensor = tensor.clone()
     if not advanced_indices:
         return tensor
 
@@ -958,6 +959,8 @@ def getitem(self, slice_spec):
 
 def setitem(a, slice_spec, updates):
     """Implementation of ndarray._with_index_*."""
+    if 0 in updates.shape:
+        return a
     if (
         isinstance(slice_spec, bool)
         or (
@@ -981,6 +984,8 @@ def strided_slice_update(input, begin, end, strides, update, begin_mask=0, end_m
     if isinstance(update, (int, float, bool)):
        update = mindtorch.tensor(update, device=input.device, dtype=input.dtype)
    sliced_tensor = execute('strided_slice', input, begin, end, strides, begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask)
+    if 0 in sliced_tensor.shape:
+        return input
    if update.shape != sliced_tensor.shape:
        update = update.broadcast_to(sliced_tensor.shape)
    update = update - sliced_tensor
diff --git a/mindtorch/ops/pointwise.py b/mindtorch/ops/pointwise.py
index f2a2d5435..9d77cf597 100644
--- a/mindtorch/ops/pointwise.py
+++ b/mindtorch/ops/pointwise.py
@@ -628,6 +628,7 @@ def log_softmax(input, dim=None, dtype=None):
     "log",
     "log1p",
     "log2",
+    "logaddexp",
     "logical_and",
     "logical_not",
     "logical_or",
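
The gpu.logaddexp helper added above relies on the standard numerically stable identity log(exp(a) + exp(b)) = max(a, b) + log1p(exp(-|a - b|)). A minimal plain-Python sketch of that identity, using only the standard library (nothing here is a mindtorch API):

import math

def logaddexp_ref(a: float, b: float) -> float:
    # Stable form: factor out the larger exponent so exp() never overflows.
    m = max(a, b)
    return m + math.log1p(math.exp(-abs(a - b)))

# Spot-check against the naive formula on values where exp() does not overflow.
for a, b in [(0.0, 0.0), (-3.5, 2.0), (10.0, 9.0)]:
    naive = math.log(math.exp(a) + math.exp(b))
    assert abs(logaddexp_ref(a, b) - naive) < 1e-12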
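
With the nn.functional.kl_div change above, reduction='batchmean' is computed as the summed pointwise KL divergence divided by the batch size input.shape[0]. A small NumPy reference sketch of that semantics, assuming log_target=False and the usual pointwise term target * (log(target) - input); kl_div_batchmean_ref is a hypothetical name used only for illustration, not a mindtorch function:

import numpy as np

def kl_div_batchmean_ref(log_input: np.ndarray, target: np.ndarray) -> float:
    # Pointwise KL term with log_target=False: target * (log(target) - log_input).
    pointwise = target * (np.log(target) - log_input)
    # 'batchmean' divides the summed loss by the batch size, not by the element count.
    return float(pointwise.sum() / log_input.shape[0])

probs = np.array([[0.4, 0.6], [0.3, 0.7]])
target = np.array([[0.5, 0.5], [0.2, 0.8]])
print(kl_div_batchmean_ref(np.log(probs), target))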
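
The new ctc_loss follows the standard CTC forward (alpha) recursion over the blank-interleaved extended target. Below is a compact plain-Python reference of the same recursion for a single sample, useful for cross-checking small cases; ctc_forward_ref and logsumexp are hypothetical helper names, and this is a sketch of the algorithm rather than the mindtorch implementation:

import math

def logsumexp(vals):
    m = max(vals)
    if m == float('-inf'):
        return m
    return m + math.log(sum(math.exp(v - m) for v in vals))

def ctc_forward_ref(log_probs, target, blank=0):
    # log_probs: list of T rows; each row is a list of C per-class log-probabilities.
    ext = [blank]
    for c in target:
        ext += [c, blank]                      # blank-interleaved extended target
    T, L = len(log_probs), len(ext)
    NEG_INF = float('-inf')
    alpha = [NEG_INF] * L
    alpha[0] = log_probs[0][ext[0]]            # start with a blank
    if L > 1:
        alpha[1] = log_probs[0][ext[1]]        # or start with the first real label
    for t in range(1, T):
        new = [NEG_INF] * L
        for s in range(L):
            cands = [alpha[s]]                 # stay at s
            if s >= 1:
                cands.append(alpha[s - 1])     # advance one position
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                cands.append(alpha[s - 2])     # skip a blank between distinct labels
            new[s] = logsumexp(cands) + log_probs[t][ext[s]]
        alpha = new
    tail = [alpha[L - 1]] if L == 1 else [alpha[L - 1], alpha[L - 2]]
    return -logsumexp(tail)

# Tiny smoke test: 3 classes (index 0 is the blank), 4 time steps, target "1 2".
T, C = 4, 3
uniform = [[math.log(1.0 / C)] * C for _ in range(T)]
print(ctc_forward_ref(uniform, [1, 2]))  # a finite, positive loss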