From fe37c11ed09909537060c5788907504bbc6cdd06 Mon Sep 17 00:00:00 2001 From: lvyufeng Date: Mon, 21 Jul 2025 20:19:08 +0800 Subject: [PATCH] fix l class ut --- .gitignore | 3 +- mindnlp/core/__init__.py | 2 + mindnlp/core/_tensor.py | 31 ++++- mindnlp/core/fft/__init__.py | 13 +- mindnlp/core/mps/__init__.py | 11 ++ mindnlp/core/nn/modules/__init__.py | 2 +- mindnlp/core/nn/modules/conv.py | 190 ++++++++++++++-------------- mindnlp/core/ops/array.py | 1 + mindnlp/core/ops/creation.py | 2 +- mindnlp/core/ops/other.py | 4 +- mindnlp/core/overrides.py | 58 ++++++++- mindnlp/core/xpu/__init__.py | 11 ++ 12 files changed, 224 insertions(+), 104 deletions(-) create mode 100644 mindnlp/core/mps/__init__.py create mode 100644 mindnlp/core/xpu/__init__.py diff --git a/.gitignore b/.gitignore index ceb763309..690c0872d 100644 --- a/.gitignore +++ b/.gitignore @@ -166,4 +166,5 @@ xiyouji.txt *.jit flagged/ -huggingface_transformers/ \ No newline at end of file +huggingface_transformers/ +diffusers/ \ No newline at end of file diff --git a/mindnlp/core/__init__.py b/mindnlp/core/__init__.py index 51c8b9488..3b9ae5178 100644 --- a/mindnlp/core/__init__.py +++ b/mindnlp/core/__init__.py @@ -26,6 +26,8 @@ Union as _Union, ) +from mindspore.runtime import Stream + strided = None contiguous_format = None preserve_format = None diff --git a/mindnlp/core/_tensor.py b/mindnlp/core/_tensor.py index e2725a954..096a0032f 100644 --- a/mindnlp/core/_tensor.py +++ b/mindnlp/core/_tensor.py @@ -538,11 +538,34 @@ def new_full(self, size, fill_value, *, dtype=None, device=None, requires_grad=F StubTensor.new_full = new_full def new_zeros(self, *size, dtype=None, device=None, requires_grad=False, layout=None, pin_memory=False): - return ops.zeros(*size, dtype=dtype if dtype is not None else self.dtype) + if isinstance(size[0], (tuple, list)): + size = size[0] + + new_size = () + for s in size: + if isinstance(s, Tensor): + s = s.item() + new_size += (s,) + return ops.zeros(*new_size, dtype=dtype if dtype is not None else self.dtype) Tensor.new_zeros = new_zeros StubTensor.new_zeros = new_zeros + def new_ones(self, *size, dtype=None, device=None, requires_grad=False, layout=None, pin_memory=False, **kwargs): + size = kwargs.get('size', size) + if isinstance(size[0], (tuple, list)): + size = size[0] + + new_size = () + for s in size: + if isinstance(s, Tensor): + s = s.item() + new_size += (s,) + return ops.ones(*new_size, dtype=dtype if dtype is not None else self.dtype) + + Tensor.new_ones = new_ones + StubTensor.new_ones = new_ones + Tensor.sum = ops.sum StubTensor.sum = ops.sum @@ -598,6 +621,12 @@ def __contains__(self, item): Tensor.mean = ops.mean StubTensor.mean = ops.mean + Tensor.amax = ops.amax + StubTensor.amax = ops.amax + + Tensor.as_strided = ops.as_strided + StubTensor.as_strided = ops.as_strided + def _rebuild_from_type_v2(func, new_type, args, state): ret = func(*args) return ret \ No newline at end of file diff --git a/mindnlp/core/fft/__init__.py b/mindnlp/core/fft/__init__.py index c208ed94d..a37d679b1 100644 --- a/mindnlp/core/fft/__init__.py +++ b/mindnlp/core/fft/__init__.py @@ -2,7 +2,7 @@ from mindspore import ops from mindspore.ops._primitive_cache import _get_cache_prim from ..configs import use_pyboost -from ..ops import narrow +from ..ops import narrow, roll from ..nn import functional as F def rfft(input, n=None, dim=-1, norm="backward"): @@ -35,4 +35,13 @@ def fftn(input, s=None, dim=None, norm=None): def fft(input, s=None, dim=-1, norm=None): return ops.fft(input, s, dim, norm) -__all__ = 
['fft', 'fftn', 'irfft', 'rfft'] \ No newline at end of file +def fftshift(x, dim=None): + return ops.fftshift(x, dim) + +def ifftn(input, s=None, dim=None, norm=None, *, out=None): + return ops.ifftn(input, s, dim, norm) + +def ifftshift(input, dim=None): + return ops.ifftshift(input, dim) + +__all__ = ['fft', 'fftn', 'irfft', 'rfft'] diff --git a/mindnlp/core/mps/__init__.py b/mindnlp/core/mps/__init__.py new file mode 100644 index 000000000..1ba158d31 --- /dev/null +++ b/mindnlp/core/mps/__init__.py @@ -0,0 +1,11 @@ +def is_available(): + return False + +def empty_cache(): + pass + +def device_count(): + return 0 + +def manual_seed(*args, **kwargs): + pass diff --git a/mindnlp/core/nn/modules/__init__.py b/mindnlp/core/nn/modules/__init__.py index 862358c19..6794de7b3 100644 --- a/mindnlp/core/nn/modules/__init__.py +++ b/mindnlp/core/nn/modules/__init__.py @@ -6,7 +6,7 @@ from .normalization import LayerNorm, GroupNorm, RMSNorm from .dropout import Dropout, Dropout2d from .activation import * -from .conv import Conv3d, Conv2d, Conv1d, ConvTranspose2d, ConvTranspose1d +from .conv import Conv3d, Conv2d, Conv1d, ConvTranspose2d, ConvTranspose1d, ConvTranspose3d from .padding import ZeroPad2d, ConstantPad2d, ConstantPad1d, ConstantPad3d from .batchnorm import BatchNorm2d, BatchNorm1d, SyncBatchNorm from .pooling import AdaptiveAvgPool2d, AvgPool1d, MaxPool2d, MaxPool1d, AdaptiveAvgPool1d, AvgPool2d diff --git a/mindnlp/core/nn/modules/conv.py b/mindnlp/core/nn/modules/conv.py index d188e219d..20666de5b 100644 --- a/mindnlp/core/nn/modules/conv.py +++ b/mindnlp/core/nn/modules/conv.py @@ -666,101 +666,101 @@ def forward(self, input: Tensor, output_size: Optional[list[int]] = None) -> Ten self.dilation, ) -# class ConvTranspose3d(_ConvTransposeNd): -# r"""Applies a 3D transposed convolution operator over an input image composed of several input -# planes. -# The transposed convolution operator multiplies each input value element-wise by a learnable kernel, -# and sums over the outputs from all input feature planes. - -# This module can be seen as the gradient of Conv3d with respect to its input. -# It is also known as a fractionally-strided convolution or -# a deconvolution (although it is not an actual deconvolution operation). - -# | :attr:`stride` controls the stride for the cross-correlation. -# | If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides -# for :attr:`padding` number of points. -# | If :attr:`output_padding` is non-zero, then the output is implicitly zero-padded on one side -# for :attr:`output_padding` number of points. -# | :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm. -# It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. -# | :attr:`groups` controls the connections between inputs and outputs. `in_channels` and `out_channels` -# must both be divisible by `groups`. -# | At groups=1, all inputs are convolved to all outputs. -# | At groups=2, the operation becomes equivalent to having two conv layers -# side by side, each seeing half the input channels, -# and producing half the output channels, and both subsequently concatenated. -# At groups=`in_channels`, each input channel is convolved with its own set of filters -# (of size `out_channels // in_channels`). 
- -# The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding` -# can either be: - -# - a single ``int`` -- in which case the same value is used for the depth, height and width dimensions -# - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension, -# the second `int` for the height dimension and the third `int` for the width dimension - -# .. note:: - -# Depending of the size of your kernel, several (of the last) -# columns of the input might be lost, because it is a valid `cross-correlation`_, -# and not a full `cross-correlation`_. -# It is up to the user to add proper padding. - -# Args: -# in_channels (int): Number of channels in the input image -# out_channels (int): Number of channels produced by the convolution -# kernel_size (int or tuple): Size of the convolving kernel -# stride (int or tuple, optional): Stride of the convolution -# padding (int or tuple, optional): Zero-padding added to both sides of the input -# output_padding (int or tuple, optional): Zero-padding added to one side of the output -# groups (int, optional): Number of blocked connections from input channels to output channels -# bias (bool, optional): If True, adds a learnable bias to the output -# dilation (int or tuple, optional): Spacing between kernel elements - -# Shape: -# - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})` -# - Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` where -# :math:`D_{out} = (D_{in} - 1) * stride[0] - 2 * padding[0] + kernel\_size[0] + output\_padding[0]` -# :math:`H_{out} = (H_{in} - 1) * stride[1] - 2 * padding[1] + kernel\_size[1] + output\_padding[1]` -# :math:`W_{out} = (W_{in} - 1) * stride[2] - 2 * padding[2] + kernel\_size[2] + output\_padding[2]` - -# Attributes: -# weight (Tensor): the learnable weights of the module of shape -# (in_channels, out_channels, kernel_size[0], kernel_size[1], kernel_size[2]) -# bias (Tensor): the learnable bias of the module of shape (out_channels) - -# Examples:: - -# >>> # With square kernels and equal stride -# >>> m = nn.ConvTranspose3d(16, 33, 3, stride=2) -# >>> # non-square kernels and unequal stride and with padding -# >>> m = nn.Conv3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(0, 4, 2)) -# >>> input = autograd.Variable(core.randn(20, 16, 10, 50, 100)) -# >>> output = m(input) - -# .. _cross-correlation: -# https://en.wikipedia.org/wiki/Cross-correlation - -# .. _link: -# https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md -# """ - -# def __init__(self, in_channels, out_channels, kernel_size, stride=1, -# padding=0, output_padding=0, groups=1, bias=True, dilation=1): -# kernel_size = _triple(kernel_size) -# stride = _triple(stride) -# padding = _triple(padding) -# dilation = _triple(dilation) -# output_padding = _triple(output_padding) -# super(ConvTranspose3d, self).__init__( -# in_channels, out_channels, kernel_size, stride, padding, dilation, -# True, output_padding, groups, bias) - -# def forward(self, input, output_size=None): -# output_padding = self._output_padding(input, output_size) -# return F.conv_transpose3d( -# input, self.weight, self.bias, self.stride, self.padding, -# output_padding, self.groups, self.dilation) +class ConvTranspose3d(_ConvTransposeNd): + r"""Applies a 3D transposed convolution operator over an input image composed of several input + planes. + The transposed convolution operator multiplies each input value element-wise by a learnable kernel, + and sums over the outputs from all input feature planes. 
+
+    This module can be seen as the gradient of Conv3d with respect to its input.
+    It is also known as a fractionally-strided convolution or
+    a deconvolution (although it is not an actual deconvolution operation).
+
+    | :attr:`stride` controls the stride for the cross-correlation.
+    | If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
+      for :attr:`padding` number of points.
+    | If :attr:`output_padding` is non-zero, then the output is implicitly zero-padded on one side
+      for :attr:`output_padding` number of points.
+    | :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm.
+      It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.
+    | :attr:`groups` controls the connections between inputs and outputs. `in_channels` and `out_channels`
+      must both be divisible by `groups`.
+    | At groups=1, all inputs are convolved to all outputs.
+    | At groups=2, the operation becomes equivalent to having two conv layers
+      side by side, each seeing half the input channels,
+      and producing half the output channels, and both subsequently concatenated.
+      At groups=`in_channels`, each input channel is convolved with its own set of filters
+      (of size `out_channels // in_channels`).
+
+    The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding`
+    can either be:
+
+    - a single ``int`` -- in which case the same value is used for the depth, height and width dimensions
+    - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension,
+      the second `int` for the height dimension and the third `int` for the width dimension
+
+    .. note::
+
+        Depending on the size of your kernel, several (of the last)
+        columns of the input might be lost, because it is a valid `cross-correlation`_,
+        and not a full `cross-correlation`_.
+        It is up to the user to add proper padding.
+
+    Args:
+        in_channels (int): Number of channels in the input image
+        out_channels (int): Number of channels produced by the convolution
+        kernel_size (int or tuple): Size of the convolving kernel
+        stride (int or tuple, optional): Stride of the convolution
+        padding (int or tuple, optional): Zero-padding added to both sides of the input
+        output_padding (int or tuple, optional): Zero-padding added to one side of the output
+        groups (int, optional): Number of blocked connections from input channels to output channels
+        bias (bool, optional): If True, adds a learnable bias to the output
+        dilation (int or tuple, optional): Spacing between kernel elements
+
+    Shape:
+        - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`
+        - Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` where
+          :math:`D_{out} = (D_{in} - 1) * stride[0] - 2 * padding[0] + kernel\_size[0] + output\_padding[0]`
+          :math:`H_{out} = (H_{in} - 1) * stride[1] - 2 * padding[1] + kernel\_size[1] + output\_padding[1]`
+          :math:`W_{out} = (W_{in} - 1) * stride[2] - 2 * padding[2] + kernel\_size[2] + output\_padding[2]`
+
+    Attributes:
+        weight (Tensor): the learnable weights of the module of shape
+            (in_channels, out_channels, kernel_size[0], kernel_size[1], kernel_size[2])
+        bias (Tensor): the learnable bias of the module of shape (out_channels)
+
+    Examples::
+
+        >>> # With square kernels and equal stride
+        >>> m = nn.ConvTranspose3d(16, 33, 3, stride=2)
+        >>> # non-square kernels and unequal stride and with padding
+        >>> m = nn.ConvTranspose3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(0, 4, 2))
+        >>> input = core.randn(20, 16, 10, 50, 100)
+        >>> output = m(input)
+
+    .. _cross-correlation:
+        https://en.wikipedia.org/wiki/Cross-correlation
+
+    .. _link:
+        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, output_padding=0, groups=1, bias=True, dilation=1):
+        kernel_size = _triple(kernel_size)
+        stride = _triple(stride)
+        padding = _triple(padding)
+        dilation = _triple(dilation)
+        output_padding = _triple(output_padding)
+        super(ConvTranspose3d, self).__init__(
+            in_channels, out_channels, kernel_size, stride, padding, dilation,
+            True, output_padding, groups, bias)
+
+    def forward(self, input, output_size=None):
+        output_padding = self._output_padding(input, output_size)
+        return F.conv_transpose3d(
+            input, self.weight, self.bias, self.stride, self.padding,
+            output_padding, self.groups, self.dilation)
 
 
 # TODO: Conv2dLocal
diff --git a/mindnlp/core/ops/array.py b/mindnlp/core/ops/array.py
index bd01b32c2..65d312222 100644
--- a/mindnlp/core/ops/array.py
+++ b/mindnlp/core/ops/array.py
@@ -120,6 +120,7 @@ def masked_select(input, mask, *, out=None):
 # narrow
 has_narrow = hasattr(mindspore.mint, 'narrow')
 def narrow(input, dim, start, length):
+    length = length.item() if isinstance(length, mindspore.Tensor) else length
     if use_pyboost() and has_narrow:
         return mindspore.mint.narrow(input, dim, start, length)
     return ops.narrow(input, dim, start, length)
diff --git a/mindnlp/core/ops/creation.py b/mindnlp/core/ops/creation.py
index cdbc795f2..bdd6848a1 100644
--- a/mindnlp/core/ops/creation.py
+++ b/mindnlp/core/ops/creation.py
@@ -191,7 +191,7 @@ def empty_like(input, *, dtype=None, layout=None, device=None, requires_grad=Fal
 # full
 has_full = hasattr(mindspore.mint, 'full')
-def full(size, fill_value, *, dtype=None, device=None):
+def full(size, fill_value, *, dtype=None, device=None, **kwargs):
     new_size = ()
     for s in size:
         if isinstance(s, mindspore.Tensor):
diff --git a/mindnlp/core/ops/other.py b/mindnlp/core/ops/other.py
index eb00bdf63..98b8febc7 100644
--- a/mindnlp/core/ops/other.py
+++ b/mindnlp/core/ops/other.py
@@ -152,8 +152,8 @@ def clone(input):
 # cumsum
 has_cumsum = hasattr(mindspore.mint, "cumsum")
-
-def cumsum(input, dim, dtype=None, out=None):
+def cumsum(input, dim=None, dtype=None, out=None, **kwargs):
+    dim = kwargs.pop('axis', dim)
     input_dtype = input.dtype
     if input_dtype == mindspore.int64:
         input = input.to(mindspore.int32)
diff --git a/mindnlp/core/overrides.py b/mindnlp/core/overrides.py
index 7808818a1..7ba241dd9 100644
--- a/mindnlp/core/overrides.py
+++ b/mindnlp/core/overrides.py
@@ -1,3 +1,4 @@
+import warnings
 from typing import Callable, Iterable, Any
 
 from mindnlp import core
@@ -82,4 +83,69 @@ def handle_torch_function(
     pass
 
 def has_torch_function(inp):
-    return hasattr(inp, "__torch_function__")
\ No newline at end of file
+    return hasattr(inp, "__torch_function__")
+
+# Minimal module-level mode stack so ``TorchFunctionMode`` below works as a
+# context manager; PyTorch keeps this stack in its C++ core, a plain list is enough here.
+_mode_stack = []
+
+def _push_mode(mode):
+    _mode_stack.append(mode)
+
+def _pop_mode():
+    return _mode_stack.pop() if _mode_stack else None
+
+class TorchFunctionMode:
+    """
+    A ``TorchFunctionMode`` allows you to override the meaning of all
+    ``__torch_function__`` overridable functions within a dynamic scope,
+    without having to actually create a tensor subclass or manually
+    monkey-patch functions in the PyTorch API. Some common situations
+    where you should use a mode:
+
+    * You want to override the meaning of factory functions, or other
+      functions that do not otherwise take a tensor as an argument
+      (these cannot be overridden with tensor subclasses).
+
+    * You want to override the behavior of all functions without needing
+      to wrap your inputs in tensor subclasses; e.g., if you are just
+      interested in logging intermediate computations.
+
+    * You want to control the order of execution of various tensor
+      subclasses explicitly, rather than implicitly via the return of
+      ``NotImplemented``.
+
+    Independent subclasses of :class:`TorchFunctionMode` are compositional:
+    modes can be pushed onto a stack using ``with MyMode():``.
+    When you call functions in the PyTorch API inside your
+    ``__torch_function__`` implementation, by default, they will forward on to
+    the next mode on the mode stack. If you want to recursively call back into
+    your current ``__torch_function__`` implementation, either explicitly
+    invoke ``self.__torch_function__(...)``, or use the context manager
+    ``enable_torch_function_mode(self, replace=self.inner)`` to make PyTorch
+    API self-referential (beware of infinite loops, in this case!)
+    """
+
+    inner: "TorchFunctionMode"
+
+    # Force metaclass to generate constructor at the base of the hierarchy
+    def __init__(self) -> None:
+        pass
+
+    def __torch_function__(self, func, types, args=(), kwargs=None):
+        raise NotImplementedError
+
+    def __enter__(self):
+        _push_mode(self)
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        _pop_mode()
+
+    @classmethod
+    def push(cls, *args, **kwargs):
+        warnings.warn(
+            "`Mode.push()` is no longer necessary and can be replaced with just `with Mode()`"
+        )
+        instance = cls(*args, **kwargs)
+        return instance
diff --git a/mindnlp/core/xpu/__init__.py b/mindnlp/core/xpu/__init__.py
new file mode 100644
index 000000000..1ba158d31
--- /dev/null
+++ b/mindnlp/core/xpu/__init__.py
@@ -0,0 +1,11 @@
+def is_available():
+    return False
+
+def empty_cache():
+    pass
+
+def device_count():
+    return 0
+
+def manual_seed(*args, **kwargs):
+    pass
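
Usage note (not part of the patch): a minimal sketch of the size handling that the new `new_zeros`/`new_ones` bindings in `_tensor.py` provide. Sizes may be passed as a tuple/list, as separate ints, or as 0-d Tensors, with Tensor entries unpacked via `.item()`. It assumes `mindnlp.core` exposes a `tensor` factory and that `ops.zeros`/`ops.ones` accept unpacked dimensions, as elsewhere in the codebase.

    # Illustrative only -- exercising the patched new_zeros/new_ones size handling.
    from mindnlp import core

    x = core.tensor([[1.0, 2.0], [3.0, 4.0]])

    a = x.new_zeros((2, 3))             # tuple form
    b = x.new_ones(2, core.tensor(3))   # mixed int / 0-d Tensor form; Tensor dims go through .item()

    print(a.shape, a.dtype)             # expected: (2, 3), dtype inherited from x
    print(b.shape, b.dtype)             # expected: (2, 3), dtype inherited from x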
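Similarly, a sketch (not part of the patch) of how a subclass of the newly added `TorchFunctionMode` is written. This only demonstrates the subclass shape and the `with`-statement push/pop behaviour backed by the `_push_mode`/`_pop_mode` helpers; whether calls are actually routed through `__torch_function__` depends on the rest of the framework's dispatch, which `handle_torch_function` in this module does not yet implement.

    # Illustrative only -- a logging mode built on the TorchFunctionMode added above.
    from mindnlp.core.overrides import TorchFunctionMode

    class LoggingMode(TorchFunctionMode):
        def __torch_function__(self, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            print(f"calling {getattr(func, '__name__', func)}")
            return func(*args, **kwargs)

    with LoggingMode():
        # The mode sits on the module-level stack for the duration of this block.
        pass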