From 1dcc73779bc484011494fa78e1629e7cab78ca20 Mon Sep 17 00:00:00 2001
From: lvyufeng
Date: Sun, 21 Sep 2025 13:06:41 +0000
Subject: [PATCH] fix x-z class on GPU

---
 mindtorch/_apis/gpu.py                   | 16 ++++++--
 mindtorch/distributed/tensor/__init__.py |  2 +-
 mindtorch/nn/functional.py               |  2 +-
 mindtorch/ops/other.py                   | 41 +++++++++++++++++++
 mindtorch/ops/reduction.py               | 51 ++++++++++++++++++++++++
 5 files changed, 107 insertions(+), 5 deletions(-)

diff --git a/mindtorch/_apis/gpu.py b/mindtorch/_apis/gpu.py
index 4e3968052..5affa4ca7 100644
--- a/mindtorch/_apis/gpu.py
+++ b/mindtorch/_apis/gpu.py
@@ -71,6 +71,7 @@ def identity(input):
 def clone(input):
     return cast(legacy.mul(input, 1), input.dtype)
 
+py_max = max
 def max(input):
     return legacy.reduce_max(input, (), False)
 
@@ -106,6 +107,12 @@ def transpose_view(input, dim0, dim1):
 def matmul(self, other):
     if self.ndim > 2:
         if self.ndim == other.ndim:
+            if self.shape[:-2] != other.shape[:-2]:
+                new_shape = ()
+                for i in range(self.ndim - 2):
+                    new_shape += (py_max([self.shape[i], other.shape[i]]),)
+                self = broadcast_to(self, new_shape + self.shape[-2:])
+                other = broadcast_to(other, new_shape + other.shape[-2:])
             return legacy.batch_mat_mul(self, other, False, False)
         else:
             self_shape = self.shape
@@ -988,9 +995,9 @@ def grid_sampler_2d(input, grid, mode='bilinear', padding_mode='zeros', align_co
 def l1_loss(input, target, reduction='mean'):
     loss = abs(sub(input, target))
     if reduction == 'mean':
-        return mean(loss, (), False, False)
+        return mean(loss, (), False, None)
     elif reduction == 'sum':
-        return sum(loss, (), False, False)
+        return sum(loss, (), False, None)
     return loss
 
 def leaky_relu(input, negative_slope):
@@ -1136,4 +1143,7 @@ def inplace_fill_tensor(input, value):
     return input
 
 def search_sorted(sorted_sequence, values, sorter, dtype, right):
-    return legacy.search_sorted(sorted_sequence, values, sorter, dtype, right)
\ No newline at end of file
+    return legacy.search_sorted(sorted_sequence, values, sorter, dtype, right)
+
+def einsum(equation, operands):
+    return legacy.einsum(operands, equation)
\ No newline at end of file
diff --git a/mindtorch/distributed/tensor/__init__.py b/mindtorch/distributed/tensor/__init__.py
index 4063d9d91..2fba75884 100644
--- a/mindtorch/distributed/tensor/__init__.py
+++ b/mindtorch/distributed/tensor/__init__.py
@@ -1,7 +1,7 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 
 import mindtorch
-import mindtorch.distributed.tensor._ops # force import all built-in dtensor ops
+# import mindtorch.distributed.tensor._ops # force import all built-in dtensor ops
 from mindtorch.distributed.device_mesh import DeviceMesh, init_device_mesh # noqa: F401
 from mindtorch.distributed.tensor._api import (
     distribute_module,
diff --git a/mindtorch/nn/functional.py b/mindtorch/nn/functional.py
index 8e34679ec..49dcac19d 100644
--- a/mindtorch/nn/functional.py
+++ b/mindtorch/nn/functional.py
@@ -361,7 +361,7 @@ def _nll_loss(inputs, target, target_dim=-1, weight=None, ignore_index=None, red
     else:
         non_pad_mask = target
     if weight is not None:
-        loss_weights = mindtorch.gather(weight, target, 0)
+        loss_weights = mindtorch.index_select(weight, 0, target)
     orig_shape = inputs.shape
     if inputs.ndim != 2:
         inputs = inputs.view(orig_shape[:2] + (-1,))
diff --git a/mindtorch/ops/other.py b/mindtorch/ops/other.py
index ba12ecc69..f2098d395 100644
--- a/mindtorch/ops/other.py
+++ b/mindtorch/ops/other.py
@@ -108,7 +108,46 @@ def cumsum(input, dim=None, dtype=None, **kwargs):
     return execute('cumsum', input, dim, dtype)
 
 # diag
+def my_diag(input_tensor, diagonal=0):
+    """
+    Manually implements the behavior of torch.diag.
+    Args:
+        input_tensor: input tensor, either 1-D (a vector) or 2-D (a matrix)
+        diagonal: diagonal offset; 0 is the main diagonal, positive values are above it, negative values below it
+    Returns:
+        a diagonal matrix or the diagonal elements, depending on the input's dimensionality
+    """
+    if input_tensor.dim() == 1:  # vector input: build a matrix with it on the requested diagonal
+        n = input_tensor.size(0) + (diagonal if diagonal >= 0 else -diagonal)
+        output = mindtorch.zeros(n, n, dtype=input_tensor.dtype, device=input_tensor.device)
+        for i in range(input_tensor.size(0)):
+            output[i + (0 if diagonal >= 0 else -diagonal), i + (diagonal if diagonal >= 0 else 0)] = input_tensor[i]
+        return output
+
+    elif input_tensor.dim() == 2:  # matrix input: extract the diagonal elements
+        rows, cols = input_tensor.shape
+        if diagonal >= 0:
+            diag_len = min(rows, cols - diagonal)
+        else:
+            diag_len = min(rows + diagonal, cols)
+
+        if diag_len <= 0:  # the requested diagonal is empty: return an empty tensor
+            return mindtorch.tensor([], dtype=input_tensor.dtype, device=input_tensor.device)
+
+        output = mindtorch.zeros(diag_len, dtype=input_tensor.dtype, device=input_tensor.device)
+        for i in range(diag_len):
+            if diagonal >= 0:
+                output[i] = input_tensor[i, i + diagonal]
+            else:
+                output[i] = input_tensor[i - diagonal, i]
+        return output
+
+    else:
+        raise RuntimeError("input tensor must be 1-D or 2-D")
+
 def diag(input, diagonal=0, *, out=None):
+    if input.device.type == 'cuda':
+        return my_diag(input, diagonal)
     return execute('diag', input, diagonal)
 
 # diag_embed
@@ -639,6 +678,8 @@ def einsum(equation, *operands):
     """
     if isinstance(operands[0], (list, tuple)):
         operands = operands[0]
+    if operands[0].device.type == 'cuda':
+        return execute('einsum', equation, operands, device=operands[0].device)
     _equation, _operands = _einsum_convert_sublist(equation, *operands)
     _einsum_check_inputargs(_equation, _operands)
     return _einsum(_equation, _operands)
diff --git a/mindtorch/ops/reduction.py b/mindtorch/ops/reduction.py
index 204ec2b9d..71afdbaf2 100644
--- a/mindtorch/ops/reduction.py
+++ b/mindtorch/ops/reduction.py
@@ -163,8 +163,59 @@ def prod(input, dim=None, keepdim=False, *, dtype=None):
 
 # nanquantile
 
 # std
+def my_std(input_tensor, dim=None, correction=1, keepdim=False):
+    """
+    Manually implements torch.std-like behavior: computes the standard deviation of a tensor.
+
+    Args:
+        input_tensor (torch.Tensor): input tensor.
+        dim (int or tuple, optional): dimension(s) to reduce over. Defaults to None, which reduces over all elements.
+        correction (int, optional): degrees-of-freedom adjustment for the divisor; 1 applies Bessel's correction. Defaults to 1.
+        keepdim (bool, optional): whether the output keeps the input's reduced dimensions. Defaults to False.
+
+    Returns:
+        torch.Tensor: tensor containing the standard-deviation values.
+    """
+    # handle an empty input tensor
+    if input_tensor.numel() == 0:
+        raise ValueError("my_std(): input tensor is empty")
+
+    # no dim given: compute the global standard deviation
+    if dim is None:
+        # mean over all elements
+        mean = input_tensor.mean()
+        # squared deviation from the mean
+        squared_diff = (input_tensor - mean) ** 2
+        # the average of the squared deviations is the variance;
+        # the divisor is adjusted by `correction`
+        n = input_tensor.numel()
+        divisor = n - correction
+        variance = squared_diff.sum() / divisor
+        # the standard deviation is the square root of the variance
+        std_dev = mindtorch.sqrt(variance)
+        return std_dev
+
+    # dim given: compute the standard deviation along the specified dimension(s)
+    else:
+        # mean along dim, with keepdim=True so it broadcasts against the input
+        mean = input_tensor.mean(dim=dim, keepdim=True)
+        # squared deviation from the mean
+        squared_diff = (input_tensor - mean) ** 2
+        # sum of the squared deviations along dim
+        sum_squared_diff = squared_diff.sum(dim=dim, keepdim=keepdim)
+        # number of elements reduced along dim
+        n = input_tensor.size(dim) if isinstance(dim, int) else mindtorch.prod(mindtorch.tensor([input_tensor.size(d) for d in dim])).item()
+        divisor = n - correction
+        # variance
+        variance = sum_squared_diff / divisor
+        # the standard deviation is the square root of the variance
+        std_dev = mindtorch.sqrt(variance)
+        return std_dev
+
 def std(input, dim=None, *, correction=1, keepdim=False, **kwargs):
     dim = kwargs.pop('axis', dim)
+    if input.device.type == 'cuda':
+        return my_std(input, dim, correction, keepdim)
     return execute('std', input, dim, correction, keepdim)
 
 # std_mean
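
Usage sketch (a minimal illustration, assuming mindtorch re-exports std, diag,
and randn at the top level like torch, and that a CUDA device is visible):

    import mindtorch

    x = mindtorch.randn(4, 5, device='cuda')

    # std() now routes to my_std on CUDA: correction=1 divides by n - 1
    # (Bessel's correction), correction=0 divides by n.
    s = mindtorch.std(x, dim=1, correction=1)

    # diag() now routes to my_diag on CUDA: a 2-D input extracts the k-th
    # diagonal, while a 1-D input embeds the vector on the k-th diagonal.
    d = mindtorch.diag(x, diagonal=1)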