4 changes: 2 additions & 2 deletions mindtorch/__init__.py
@@ -178,9 +178,9 @@ def _running_with_deploy():
from ._lowrank import svd_lowrank
from .random import get_rng_state, initial_seed, manual_seed, seed, set_rng_state

__version__ = 'test_version_no_value'

from .torch_proxy import initialize_torch_proxy, setup_metadata_patch
initialize_torch_proxy()
setup_metadata_patch()


__version__ = 'test_version_no_value'
30 changes: 30 additions & 0 deletions mindtorch/_apis/gpu.py
@@ -265,6 +265,8 @@ def less(input, other):
return legacy.less(input, other)

def select(condition, x, y):
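# Short-circuit for empty tensors: return an empty result with the condition's shape.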
if 0 in condition.shape:
return mindspore.Tensor(Tensor_(shape=condition.shape, dtype=x.dtype))
if isinstance(x, numbers.Number) or x.ndim == 0:
x = fill_scalar(condition.shape, x, None)
if isinstance(y, numbers.Number) or y.ndim == 0:
@@ -1150,3 +1152,31 @@ def einsum(equation, operands):
def unique2(input, sorted, return_inverse, return_counts):
outs = legacy.unique(input)
return outs + (None,)

def logaddexp(input, other):
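# Numerically stable log(exp(input) + exp(other)):
# log(e^a + e^b) = max(a, b) + log1p(exp(-|a - b|)), which avoids overflow for large inputs.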
m = maximum(input, other)
abs_val = abs(sub(input, other))
exp_val = exp(neg(abs_val))
y = add(m, log1p(exp_val))
return y

def kl_div(input, target, reduction, log_target):
if log_target:
target = log(target)

if reduction == 'batchmean':
kl_div_sum = legacy.kl_div_loss(input, target, 'sum')
# shape = input.shape
# batch_size = shape[0]
# return div(kl_div_sum, batch_size)
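# For 'batchmean' the division by the batch size is left to the caller
# (mindtorch.nn.functional.kl_div), so only the summed loss is returned here.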
return kl_div_sum

if reduction == 'mean':
kl_div_sum = legacy.kl_div_loss(input, target, 'sum')
shape = input.shape
total_size = 1
for dim in shape:
total_size = total_size * dim
return div(kl_div_sum, total_size)

return legacy.kl_div_loss(input, target, reduction)
5 changes: 4 additions & 1 deletion mindtorch/_apis/npu.py
@@ -1612,4 +1612,7 @@ def new_empty(input, size, dtype):
return pyboost.new_empty_op(input, size, dtype, 'Ascend')

def new_ones(input, size, dtype):
return pyboost.new_ones_op(input, size, dtype)
return pyboost.new_ones_op(input, size, dtype)

def kl_div(input, target, reduction, log_target):
return pyboost.kl_div_op(input, target, reduction, log_target)
2 changes: 1 addition & 1 deletion mindtorch/_tensor.py
@@ -177,7 +177,7 @@ def cuda(self, device=None, non_blocking=False):
if DEVICE_TARGET == 'Ascend':
return self.npu(device, non_blocking)
if device is None:
device = device_('gpu', 0)
device = device_('cuda', 0)
return self.to(device, non_blocking=non_blocking)

def requires_grad_(self, requires_grad=True):
140 changes: 135 additions & 5 deletions mindtorch/nn/functional.py
@@ -557,10 +557,16 @@ def smooth_l1_loss(input, target, beta=1.0, reduction='none'):
target = target.to(mindtorch.float32)
return ops.smooth_l1_loss(input, target, beta, reduction)

def kl_div(logits, labels, reduction='mean', log_target=False):
if log_target:
labels = ops.log(labels)
return ops.kl_div(logits, labels, reduction)
def kl_div(input, target, reduction='mean', log_target=False):
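# 'batchmean' follows the PyTorch convention: sum the pointwise KL terms, then divide by the batch size.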
if reduction == 'batchmean':
reduced = execute('kl_div', input, target, 'sum', log_target)
else:
reduced = execute('kl_div', input, target, reduction, log_target)

if reduction == 'batchmean' and input.ndim != 0:
reduced = mindtorch.div(reduced, input.shape[0])

return reduced

def softmax(input, dim=-1, *, dtype=None):
if dtype is not None:
@@ -1602,8 +1608,132 @@ def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1):
return execute('col2im', input, output_size, kernel_size, dilation, padding, stride)

# def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean', zero_infinity=False):
# return execute('ctc_loss', log_probs, targets, input_lengths, target_lengths, blank, reduction, zero_infinity)

def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean', zero_infinity=False):
return execute('ctc_loss', log_probs, targets, input_lengths, target_lengths, blank, reduction, zero_infinity)
"""
使用向量化操作手动实现 CTC Loss,提升计算效率。
支持批处理,并在内部使用张量操作避免循环。

参数:
log_probs: Tensor of size (T, N, C), 其中 T=input length, N=batch size, C=number of classes (包括空白符).
通常应经过 log_softmax 处理。
targets: Tensor of size (N, S) 或 (sum(target_lengths)), 表示目标序列。不包含空白符。
input_lengths: Tensor or tuple of size (N), 表示每个输入序列的实际长度。
target_lengths: Tensor or tuple of size (N), 表示每个目标序列的实际长度。
blank (int, optional): 空白符的类别索引。默认为 0。
reduction (str, optional): 指定损失的缩减方式:'none' | 'mean' | 'sum'. 默认为 'mean'.
zero_infinity (bool, optional): 是否将无限损失(及其梯度)归零。默认为 False。

返回:
Tensor: 计算出的 CTC 损失。
"""
T, N, C = log_probs.size()
device = log_probs.device
dtype = log_probs.dtype

# Initialize the per-sample loss tensor
losses = mindtorch.zeros(N, device=device, dtype=dtype)

# Handle both targets layouts: (N, S) or (sum(target_lengths))
if targets.dim() == 1:
# targets is the 1D concatenated form
targets_ = targets
else:
# targets is the 2D (N, S) form
targets_ = targets.view(-1)

# Process each sample in the batch
for n in range(N):
T_n = input_lengths[n]
S_n = target_lengths[n]

if S_n == 0:
# If the target length is 0, the loss is minus the sum of the blank log-probabilities
blank_log_probs = log_probs[:T_n, n, blank]
losses[n] = -mindtorch.sum(blank_log_probs)
if zero_infinity and mindtorch.isinf(losses[n]):
losses[n] = 0.0
continue

# Target sequence for the current sample
if targets.dim() == 1:
start_index = sum(target_lengths[:n])
end_index = start_index + S_n
target_seq = targets_[start_index:end_index]
else:
target_seq = targets[n, :S_n]

# Build the extended target sequence (length L = 2 * S_n + 1)
extended_targets = mindtorch.zeros(2 * S_n + 1, device=device, dtype=mindtorch.long)
extended_targets[0] = blank
extended_targets[1::2] = target_seq
extended_targets[2::2] = blank
L = len(extended_targets)

# Initialize the forward variable alpha with shape (T_n, L)
alpha = mindtorch.full((T_n, L), mindtorch.finfo(dtype).min, device=device, dtype=dtype)
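# alpha[t, s] holds the log-probability of all alignments of the first t+1 frames
# that end at position s of the extended target sequence.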

# Initialize the first time step
alpha[0, 0] = log_probs[0, n, extended_targets[0]] # start with the blank
if L > 1:
alpha[0, 1] = log_probs[0, n, extended_targets[1]] # start with the first real label

# Forward recursion over alpha
for t in range(1, T_n):
# Log-probabilities of every extended-target symbol at the current time step
# log_probs_t has shape (L,)
log_probs_t = log_probs[t, n, extended_targets]

# alpha from the previous time step, used for the vectorized update
alpha_prev = alpha[t-1]

# Case 1: transition from s (stay)
stay_log_prob = alpha_prev + log_probs_t

# Case 2: transition from s-1 (move one step)
move_one_log_prob = mindtorch.empty_like(stay_log_prob)
move_one_log_prob[0] = -float('inf') # the first position has no s-1
move_one_log_prob[1:] = alpha_prev[:-1] + log_probs_t[1:]

# Case 3: transition from s-2 (skip one symbol, only when allowed)
move_two_log_prob = mindtorch.empty_like(stay_log_prob)
move_two_log_prob[:2] = -float('inf') # the first two positions have no s-2
# Condition: s >= 2, the current symbol is not blank, and it differs from the symbol at s-2
condition = (extended_targets[2:] != extended_targets[:-2]) & (extended_targets[2:] != blank)
# Apply the condition to alpha_prev[s-2]
eligible_s2 = alpha_prev[:-2].clone()
eligible_s2[~condition] = -float('inf') # disallowed transitions become -inf
move_two_log_prob[2:] = eligible_s2 + log_probs_t[2:]

# Combine the three cases with logaddexp
# First combine stay and move_one
combined_stay_move = mindtorch.logaddexp(stay_log_prob, move_one_log_prob)
# Then fold in move_two
alpha[t] = mindtorch.logaddexp(combined_stay_move, move_two_log_prob)

# Total log-likelihood (last two positions at the final time step)
if L > 1:
total_log_prob = mindtorch.logaddexp(alpha[T_n-1, L-1], alpha[T_n-1, L-2])
else:
total_log_prob = alpha[T_n-1, L-1]

losses[n] = -total_log_prob

# Handle zero_infinity
if zero_infinity and mindtorch.isinf(losses[n]):
losses[n] = 0.0

# Apply the requested reduction
if reduction == 'none':
return losses
elif reduction == 'sum':
return losses.sum()
elif reduction == 'mean':
return losses.mean()
else:
raise ValueError("reduction should be 'none', 'sum', or 'mean'.")
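
# A minimal usage sketch (hypothetical example, assuming mindtorch mirrors the
# torch API for randn, log_softmax, randint, and full):
#
#   T, N, C, S = 50, 4, 20, 10
#   log_probs = mindtorch.randn(T, N, C).log_softmax(2)
#   targets = mindtorch.randint(1, C, (N, S), dtype=mindtorch.long)
#   input_lengths = mindtorch.full((N,), T, dtype=mindtorch.long)
#   target_lengths = mindtorch.randint(1, S + 1, (N,), dtype=mindtorch.long)
#   loss = ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean')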

def one_hot(tensor, num_classes=-1):
return execute('one_hot', tensor, num_classes)
5 changes: 5 additions & 0 deletions mindtorch/ops/array.py
@@ -825,6 +825,7 @@ def _slice_helper(tensor, slice_spec, do_update=False, updates=None):
strides,
begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask
)
tensor = tensor.clone()

if not advanced_indices:
return tensor
@@ -958,6 +959,8 @@ def getitem(self, slice_spec):

def setitem(a, slice_spec, updates):
"""Implementation of ndarray._with_index_*."""
if 0 in updates.shape:
return a
if (
isinstance(slice_spec, bool)
or (
@@ -981,6 +984,8 @@ def strided_slice_update(input, begin, end, strides, update, begin_mask=0, end_m
if isinstance(update, (int, float, bool)):
update = mindtorch.tensor(update, device=input.device, dtype=input.dtype)
sliced_tensor = execute('strided_slice', input, begin, end, strides, begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask)
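# An empty slice selects nothing, so there is nothing to update.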
if 0 in sliced_tensor.shape:
return input
if update.shape != sliced_tensor.shape:
update = update.broadcast_to(sliced_tensor.shape)
update = update - sliced_tensor
1 change: 1 addition & 0 deletions mindtorch/ops/pointwise.py
@@ -628,6 +628,7 @@ def log_softmax(input, dim=None, dtype=None):
"log",
"log1p",
"log2",
"logaddexp",
"logical_and",
"logical_not",
"logical_or",