Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions mindtorch/_apis/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def identity(input):
def clone(input):
    """Build a new tensor from ``input``.

    The input is multiplied by 1 through ``legacy.mul`` and the product is
    cast back to ``input.dtype`` before being returned.
    """
    scaled = legacy.mul(input, 1)
    return cast(scaled, input.dtype)

# Keep a handle on the builtin before the tensor-level ``max`` below shadows it.
py_max = max


def max(input):
    """Reduce ``input`` with ``legacy.reduce_max``.

    The call passes an empty tuple and ``False`` as the second and third
    arguments.
    """
    reduced = legacy.reduce_max(input, (), False)
    return reduced

Expand Down Expand Up @@ -106,6 +107,12 @@ def transpose_view(input, dim0, dim1):
def matmul(self, other):
if self.ndim > 2:
if self.ndim == other.ndim:
if self.shape[:-2] != other.shape[:-2]:
new_shape = ()
for i in range(self.ndim - 2):
new_shape += (py_max([self.shape[i], other.shape[i]]),)
self = broadcast_to(self, new_shape + self.shape[-2:])
other = broadcast_to(other, new_shape + other.shape[-2:])
return legacy.batch_mat_mul(self, other, False, False)
else:
self_shape = self.shape
Expand Down Expand Up @@ -988,9 +995,9 @@ def grid_sampler_2d(input, grid, mode='bilinear', padding_mode='zeros', align_co
def l1_loss(input, target, reduction='mean'):
    """Compute ``abs(sub(input, target))`` and optionally reduce it.

    Args:
        input: Predicted tensor.
        target: Tensor subtracted from ``input``.
        reduction: ``'mean'`` reduces with ``mean``, ``'sum'`` reduces with
            ``sum``; any other value returns the unreduced loss.

    Returns:
        The reduced loss for ``'mean'`` / ``'sum'``, otherwise the
        element-wise loss tensor.
    """
    loss = abs(sub(input, target))
    # Each branch previously carried two consecutive returns, the second of
    # which could never run; only the call ending in None is kept.
    # NOTE(review): the trailing None is presumably the dtype slot of the
    # reduction helpers -- confirm against their signatures.
    if reduction == 'mean':
        return mean(loss, (), False, None)
    if reduction == 'sum':
        return sum(loss, (), False, None)
    return loss

def leaky_relu(input, negative_slope):
Expand Down Expand Up @@ -1136,4 +1143,7 @@ def inplace_fill_tensor(input, value):
return input

def search_sorted(sorted_sequence, values, sorter, dtype, right):
    """Forward all arguments, unchanged and in order, to ``legacy.search_sorted``.

    Returns:
        Whatever ``legacy.search_sorted`` returns.
    """
    # A second, identical return statement used to follow this one and was
    # unreachable; it has been removed.
    return legacy.search_sorted(sorted_sequence, values, sorter, dtype, right)

def einsum(equation, operands):
    """Call ``legacy.einsum`` for the given equation and operands.

    Note the argument order flips: ``legacy.einsum`` receives ``operands``
    first and ``equation`` second.
    """
    result = legacy.einsum(operands, equation)
    return result
2 changes: 1 addition & 1 deletion mindtorch/distributed/tensor/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates

import mindtorch
import mindtorch.distributed.tensor._ops # force import all built-in dtensor ops
# import mindtorch.distributed.tensor._ops # force import all built-in dtensor ops
from mindtorch.distributed.device_mesh import DeviceMesh, init_device_mesh # noqa: F401
from mindtorch.distributed.tensor._api import (
distribute_module,
Expand Down
2 changes: 1 addition & 1 deletion mindtorch/nn/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def _nll_loss(inputs, target, target_dim=-1, weight=None, ignore_index=None, red
else:
non_pad_mask = target
if weight is not None:
loss_weights = mindtorch.gather(weight, target, 0)
loss_weights = mindtorch.index_select(weight, 0, target)
orig_shape = inputs.shape
if inputs.ndim != 2:
inputs = inputs.view(orig_shape[:2] + (-1,))
Expand Down
41 changes: 41 additions & 0 deletions mindtorch/ops/other.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,46 @@ def cumsum(input, dim=None, dtype=None, **kwargs):
return execute('cumsum', input, dim, dtype)

# diag
def my_diag(input_tensor, diagonal=0):
    """Hand-written equivalent of ``torch.diag``.

    Args:
        input_tensor: A 1-D tensor (vector) or a 2-D tensor (matrix).
        diagonal: Which diagonal to use: 0 is the main diagonal, positive
            values select diagonals above it, negative values below it.

    Returns:
        For 1-D input of length ``n``: a square 2-D tensor of side
        ``n + |diagonal|`` that is zero except on the selected diagonal,
        which holds the input values.
        For 2-D input: a 1-D tensor with the elements of the selected
        diagonal, or an empty tensor when that diagonal lies outside the
        matrix.

    Raises:
        RuntimeError: If the input is neither 1-D nor 2-D.
    """
    ndim = input_tensor.dim()
    if ndim not in (1, 2):
        raise RuntimeError("输入张量必须是一维或二维")

    # A positive offset shifts the column index, a negative one the row index.
    row_shift = -diagonal if diagonal < 0 else 0
    col_shift = diagonal if diagonal > 0 else 0

    if ndim == 1:
        # Vector input: build the matrix and honour the requested diagonal
        # (the matrix grows by |diagonal| so the whole vector still fits).
        n = input_tensor.size(0)
        side = n + row_shift + col_shift
        output = mindtorch.zeros(side, side, dtype=input_tensor.dtype, device=input_tensor.device)
        for i in range(n):
            output[i + row_shift, i + col_shift] = input_tensor[i]
        return output

    # Matrix input: extract the elements of the selected diagonal.
    rows, cols = input_tensor.shape
    diag_len = min(rows - row_shift, cols - col_shift)
    if diag_len <= 0:
        # The requested diagonal does not intersect the matrix.
        return mindtorch.tensor([], dtype=input_tensor.dtype, device=input_tensor.device)

    output = mindtorch.zeros(diag_len, dtype=input_tensor.dtype, device=input_tensor.device)
    for i in range(diag_len):
        output[i] = input_tensor[i + row_shift, i + col_shift]
    return output

def diag(input, diagonal=0, *, out=None):
    """Dispatch ``diag`` according to the device of ``input``.

    Tensors whose ``device.type`` is ``'cuda'`` go through ``my_diag``;
    everything else is handed to ``execute('diag', ...)``. The ``out``
    keyword is accepted but not used.
    """
    on_cuda = input.device.type == 'cuda'
    if not on_cuda:
        return execute('diag', input, diagonal)
    return my_diag(input, diagonal)

# diag_embed
Expand Down Expand Up @@ -639,6 +678,8 @@ def einsum(equation, *operands):
"""
if isinstance(operands[0], (list, tuple)):
operands = operands[0]
if operands[0].device.type == 'cuda':
return execute('einsum', equation, operands, device=operands[0].device)
_equation, _operands = _einsum_convert_sublist(equation, *operands)
_einsum_check_inputargs(_equation, _operands)
return _einsum(_equation, _operands)
Expand Down
51 changes: 51 additions & 0 deletions mindtorch/ops/reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,59 @@ def prod(input, dim=None, keepdim=False, *, dtype=None):
# nanquantile

# std
def my_std(input_tensor, dim=None, unbiased=True, keepdim=False):
    """Hand-written equivalent of ``torch.std``.

    Args:
        input_tensor: Input tensor.
        dim: An int or a sequence of ints naming the dimension(s) to reduce.
            ``None`` (default) reduces over every element.
        unbiased: If True (default) divide by ``n - 1`` (Bessel's
            correction); otherwise divide by ``n``.
        keepdim: If True, reduced dimensions are kept with size 1. This is
            honoured for ``dim=None`` as well.

    Returns:
        A tensor holding the standard deviation.

    Raises:
        ValueError: If the input tensor has no elements.
    """
    if input_tensor.numel() == 0:
        raise ValueError("my_std(): input tensor is empty")

    if dim is None:
        # Global standard deviation over every element.
        mean = input_tensor.mean()
        squared_diff = (input_tensor - mean) ** 2
        n = input_tensor.numel()
        divisor = n - 1 if unbiased else n
        std_dev = mindtorch.sqrt(squared_diff.sum() / divisor)
        if keepdim:
            # Every dimension was reduced, so each one is kept with size 1.
            std_dev = std_dev.reshape((1,) * input_tensor.dim())
        return std_dev

    # keepdim=True on the mean so it broadcasts against the input.
    mean = input_tensor.mean(dim=dim, keepdim=True)
    squared_diff = (input_tensor - mean) ** 2
    sum_squared_diff = squared_diff.sum(dim=dim, keepdim=keepdim)

    # Number of elements folded into each output value, computed with plain
    # Python ints instead of building a tensor and reading it back.
    reduced_dims = (dim,) if isinstance(dim, int) else tuple(dim)
    n = 1
    for d in reduced_dims:
        n *= input_tensor.size(d)
    divisor = (n - 1) if unbiased else n

    variance = sum_squared_diff / divisor
    return mindtorch.sqrt(variance)

def std(input, dim=None, *, correction=1, keepdim=False, **kwargs):
    """Dispatch ``std`` according to the device of ``input``.

    An ``axis`` keyword, when present, overrides ``dim``. Tensors whose
    ``device.type`` is ``'cuda'`` are handled by ``my_std``, with
    ``correction`` collapsed to a boolean (any truthy value selects the
    unbiased estimate); all other tensors go to ``execute('std', ...)``
    with ``correction`` passed through unchanged.
    """
    dim = kwargs.pop('axis', dim)
    if input.device.type != 'cuda':
        return execute('std', input, dim, correction, keepdim)
    unbiased = bool(correction)
    return my_std(input, dim, unbiased, keepdim)

# std_mean
Expand Down