diff --git a/mindnlp/core/_dynamo/__init__.py b/mindnlp/core/_dynamo/__init__.py
index b72e1e38e..03a208a13 100644
--- a/mindnlp/core/_dynamo/__init__.py
+++ b/mindnlp/core/_dynamo/__init__.py
@@ -19,4 +19,7 @@
     # substitute_in_graph,
 )
 
-from . import eval_frame
\ No newline at end of file
+from . import eval_frame
+
+def reset():
+    pass
diff --git a/mindnlp/core/_tensor.py b/mindnlp/core/_tensor.py
index fdb02b147..5c4412d18 100644
--- a/mindnlp/core/_tensor.py
+++ b/mindnlp/core/_tensor.py
@@ -30,6 +30,7 @@ class StubTensor: pass
     mindspore.int16: 2,
     mindspore.bfloat16: 2,
     mindspore.float16: 2,
+    mindspore.bool_: 1
 }
 
 DEVICE_MAP = {
@@ -450,6 +451,18 @@ def clamp_min(self, value):
     Tensor.index_copy_ = ops.inplace_index_copy
     StubTensor.index_copy_ = ops.inplace_index_copy
 
+    Tensor.max = ops.max
+    StubTensor.max = ops.max
+
+    Tensor.min = ops.min
+    StubTensor.min = ops.min
+
+    Tensor.squeeze_ = ops.inplace_squeeze
+    StubTensor.squeeze_ = ops.inplace_squeeze
+
+    Tensor.unsqueeze_ = ops.inplace_unsqueeze
+    StubTensor.unsqueeze_ = ops.inplace_unsqueeze
+
 
 def _rebuild_from_type_v2(func, new_type, args, state):
     ret = func(*args)
diff --git a/mindnlp/core/nn/attention/__init__.py b/mindnlp/core/nn/attention/__init__.py
index e69de29bb..7d0e25b2d 100644
--- a/mindnlp/core/nn/attention/__init__.py
+++ b/mindnlp/core/nn/attention/__init__.py
@@ -0,0 +1,5 @@
+class SDPBackend:
+    pass
+
+def sdpa_kernel(*args, **kwargs):
+    pass
diff --git a/mindnlp/core/nn/functional.py b/mindnlp/core/nn/functional.py
index c8a6b0378..657895647 100644
--- a/mindnlp/core/nn/functional.py
+++ b/mindnlp/core/nn/functional.py
@@ -300,7 +300,7 @@ def pad(input, pad, mode='constant', value=0.0):
         new_pad = ()
         for idx, pad_v in enumerate(pad):
             if pad_v < 0:
-                dim = idx // 2
+                dim = input.ndim - 1 - idx // 2
                 input = input.narrow(dim, 0, input.shape[dim] + pad_v)
                 pad_v = 0
             new_pad += (pad_v,)
@@ -530,7 +530,7 @@ def conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
         pad_mode = padding
         pad = (0,) * 4
 
-    _conv2d = _get_cache_prim(ops.Conv2D)(out_channel=weight.shape[0] * groups,
+    _conv2d = _get_cache_prim(ops.Conv2D)(out_channel=weight.shape[0],
                                           kernel_size=(1, weight.shape[-1]),
                                           mode=1,
                                           pad_mode=pad_mode,
@@ -593,6 +593,12 @@ def conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
     """
     raise ValueError("Requires mindspore >= 2.3.0 by default, or set into pyboost mode by calling torch.config.set_byboost(True).")
 
+def conv_transpose1d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
+    return mint.nn.functional.conv_transpose1d(input, weight, bias, stride, padding, output_padding, groups, dilation)
+
+
+def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
+    return mint.nn.functional.conv_transpose2d(input, weight, bias, stride, padding, output_padding, groups, dilation)
 
 def max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False):
     if use_pyboost():
diff --git a/mindnlp/core/nn/modules/conv.py b/mindnlp/core/nn/modules/conv.py
index efdef4664..d188e219d 100644
--- a/mindnlp/core/nn/modules/conv.py
+++ b/mindnlp/core/nn/modules/conv.py
@@ -322,11 +322,11 @@ def forward(self, input: Tensor) -> Tensor:
 class _ConvTransposeNd(_ConvNd):
     def __init__(self, in_channels, out_channels, kernel_size, stride,
                  padding, dilation, transposed, output_padding,
-                 groups, bias, padding_mode, dtype=None) -> None:
+                 groups, bias, padding_mode, dtype=None, device=None) -> None:
         if padding_mode != 'zeros':
             raise ValueError(f'Only "zeros" padding mode is supported for {self.__class__.__name__}')
 
-        factory_kwargs = {'dtype': dtype}
+        factory_kwargs = {'dtype': dtype, 'device': device}
         super().__init__(
             in_channels, out_channels, kernel_size, stride,
             padding, dilation, transposed, output_padding,
@@ -426,62 +426,71 @@ class ConvTranspose1d(_ConvTransposeNd):
         bias (Tensor): the learnable bias of the module of shape (out_channels)
     """
 
-    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
-                 padding=0, output_padding=0, groups=1, bias=True, dilation=1, padding_mode: str = 'zeros'):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: _size_1_t,
+        stride: _size_1_t = 1,
+        padding: _size_1_t = 0,
+        output_padding: _size_1_t = 0,
+        groups: int = 1,
+        bias: bool = True,
+        dilation: _size_1_t = 1,
+        padding_mode: str = "zeros",
+        device=None,
+        dtype=None,
+    ) -> None:
+        factory_kwargs = {"device": device, "dtype": dtype}
         kernel_size = _single(kernel_size)
         stride = _single(stride)
         padding = _single(padding)
         dilation = _single(dilation)
         output_padding = _single(output_padding)
-        super(ConvTranspose1d, self).__init__(
-            in_channels, out_channels, kernel_size, stride, padding, dilation,
-            True, output_padding, groups, bias, padding_mode)
-
-        pad_mode = 'pad'
-        pad = padding
-        if isinstance(padding, tuple):
-            pad = (0, 0, padding[0], padding[0])
-        elif isinstance(padding, int):
-            pad = (0, 0) + (padding,) * 2
-        if not isinstance(padding, (int, tuple)):
-            pad_mode = padding
-            pad = (0,) * 4
-
-        # cause Conv2DTranspose's out_channel refers to Conv2D's out_channel.
-        self.conv2d_transpose = mops.Conv2DTranspose(out_channel=self.out_channels,
-                                                     kernel_size=(1,) + self.kernel_size,
-                                                     mode=1,
-                                                     pad_mode=pad_mode,
-                                                     pad=pad,
-                                                     stride=(1,) + self.stride,
-                                                     dilation=(1,) + self.dilation,
-                                                     group=self.groups)
-        self.h_add = _deconv_output_length(pad_mode, 1, 1, 1, pad[0] + pad[1])
-        self.w_add = _deconv_output_length(pad_mode, kernel_size[0], stride[0], dilation[0], pad[2] + pad[3])
-
-    def forward(self, input, output_size=None):
-        if self.padding_mode != 'zeros':
-            raise ValueError('Only `zeros` padding mode is supported for ConvTranspose2d')
+        super().__init__(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            padding,
+            dilation,
+            True,
+            output_padding,
+            groups,
+            bias,
+            padding_mode,
+            **factory_kwargs,
+        )
+
+    def forward(self, input: Tensor, output_size: Optional[list[int]] = None) -> Tensor:
+        if self.padding_mode != "zeros":
+            raise ValueError(
+                "Only `zeros` padding mode is supported for ConvTranspose1d"
+            )
 
         assert isinstance(self.padding, tuple)
         # One cannot replace List by Tuple or Sequence in "_output_padding" because
         # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
         num_spatial_dims = 1
         output_padding = self._output_padding(
-            input, output_size, self.stride, self.padding, self.kernel_size,  # type: ignore[arg-type]
-            num_spatial_dims, self.dilation)  # type: ignore[arg-type]
-        input = mops.expand_dims(input, 2)
-        n, _, h, w = input.shape
-        conv2d_trans_ret = self.conv2d_transpose(input, self.weight.expand_dims(2),
-                                                 (n, self.out_channels,
-                                                  h + self.h_add,
-                                                  w * self.stride[0] + self.w_add))
-        if self.bias is not None:
-            conv2d_trans_ret = mops.bias_add(conv2d_trans_ret, self.bias)
-
-        conv2d_trans_ret = conv2d_trans_ret.squeeze(2)
-        conv2d_trans_ret = ops.pad(conv2d_trans_ret, (0,) + output_padding, value=0.)
-        return conv2d_trans_ret
+            input,
+            output_size,
+            self.stride,  # type: ignore[arg-type]
+            self.padding,  # type: ignore[arg-type]
+            self.kernel_size,  # type: ignore[arg-type]
+            num_spatial_dims,
+            self.dilation,  # type: ignore[arg-type]
+        )
+        return F.conv_transpose1d(
+            input,
+            self.weight,
+            self.bias,
+            self.stride,
+            self.padding,
+            output_padding,
+            self.groups,
+            self.dilation,
+        )
 
 
 def _deconv_output_length(pad_mode, filter_size, stride_size, dilation_size, padding):
@@ -582,66 +591,80 @@ class ConvTranspose2d(_ConvTransposeNd):
     https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
     """
 
-    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
-                 padding=0, output_padding=0, groups=1, bias=True, dilation=1,
-                 padding_mode='zeros', dtype=None):
-        factory_kwargs = {'dtype': dtype}
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: _size_2_t,
+        stride: _size_2_t = 1,
+        padding: _size_2_t = 0,
+        output_padding: _size_2_t = 0,
+        groups: int = 1,
+        bias: bool = True,
+        dilation: _size_2_t = 1,
+        padding_mode: str = "zeros",
+        device=None,
+        dtype=None,
+    ) -> None:
+        factory_kwargs = {"device": device, "dtype": dtype}
         kernel_size = _pair(kernel_size)
         stride = _pair(stride)
         padding = _pair(padding)
         dilation = _pair(dilation)
         output_padding = _pair(output_padding)
         super().__init__(
-            in_channels, out_channels, kernel_size, stride, padding, dilation,
-            True, output_padding, groups, bias, padding_mode, **factory_kwargs)
-
-        pad_mode = 'pad'
-        pad = padding
-        if isinstance(padding, tuple):
-            pad = (padding[0], padding[0], padding[1], padding[1])
-        elif isinstance(padding, int):
-            pad = (padding,) * 4
-        if not isinstance(padding, (int, tuple)):
-            pad_mode = padding
-            pad = (0,) * 4
-
-        # cause Conv2DTranspose's out_channel refers to Conv2D's out_channel.
-        self.conv2d_transpose = mops.Conv2DTranspose(out_channel=in_channels,
-                                                     kernel_size=kernel_size,
-                                                     mode=1,
-                                                     pad_mode=pad_mode,
-                                                     pad=pad,
-                                                     stride=stride,
-                                                     dilation=dilation,
-                                                     group=groups)
-
-        self.h_add = _deconv_output_length(pad_mode, kernel_size[0], stride[0], dilation[0], pad[0] + pad[1])
-        self.w_add = _deconv_output_length(pad_mode, kernel_size[1], stride[1], dilation[1], pad[2] + pad[3])
-
-    def forward(self, input, output_size=None):
-        if self.padding_mode != 'zeros':
-            raise ValueError('Only `zeros` padding mode is supported for ConvTranspose2d')
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            padding,
+            dilation,
+            True,
+            output_padding,
+            groups,
+            bias,
+            padding_mode,
+            **factory_kwargs,
+        )
+
+    def forward(self, input: Tensor, output_size: Optional[list[int]] = None) -> Tensor:
+        """
+        Performs the forward pass.
+
+        Attributes:
+            input (Tensor): The input tensor.
+            output_size (list[int], optional): A list of integers representing
+                the size of the output tensor. Default is None.
+        """
+        if self.padding_mode != "zeros":
+            raise ValueError(
+                "Only `zeros` padding mode is supported for ConvTranspose2d"
+            )
 
         assert isinstance(self.padding, tuple)
         # One cannot replace List by Tuple or Sequence in "_output_padding" because
         # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
         num_spatial_dims = 2
         output_padding = self._output_padding(
-            input, output_size, self.stride, self.padding, self.kernel_size,  # type: ignore[arg-type]
-            num_spatial_dims, self.dilation)  # type: ignore[arg-type]
-
-        n, _, h, w = input.shape
-        conv2d_trans_ret = self.conv2d_transpose(input, self.weight,
-                                                 (n, self.out_channels,
-                                                  h * self.stride[0] + self.h_add,
-                                                  w * self.stride[1] + self.w_add))
-        if self.bias is not None:
-            conv2d_trans_ret = mops.bias_add(conv2d_trans_ret, self.bias)
-
-        conv2d_trans_ret = ops.pad(conv2d_trans_ret, output_padding, value=0.)
-
-        return conv2d_trans_ret
+            input,
+            output_size,
+            self.stride,  # type: ignore[arg-type]
+            self.padding,  # type: ignore[arg-type]
+            self.kernel_size,  # type: ignore[arg-type]
+            num_spatial_dims,
+            self.dilation,  # type: ignore[arg-type]
+        )
+        return F.conv_transpose2d(
+            input,
+            self.weight,
+            self.bias,
+            self.stride,
+            self.padding,
+            output_padding,
+            self.groups,
+            self.dilation,
+        )
 
 
 # class ConvTranspose3d(_ConvTransposeNd):
 #     r"""Applies a 3D transposed convolution operator over an input image composed of several input
diff --git a/mindnlp/core/nn/modules/rnn.py b/mindnlp/core/nn/modules/rnn.py
index 5c01c05f1..99a2a60b6 100644
--- a/mindnlp/core/nn/modules/rnn.py
+++ b/mindnlp/core/nn/modules/rnn.py
@@ -15,12 +15,15 @@
 """RNN operators module, include RNN, GRU."""
 import math
 import warnings
+from mindspore import context
+from mindspore import ops as P
 
-from ..parameter import Parameter
+from mindnlp import core
 from .module import Module
 from .dropout import Dropout
-from ... import ops
+from ..parameter import Parameter
 from .. import init
+from ... import ops
 
 __all__ = ['LSTM', 'GRU', 'RNN']
 
@@ -272,7 +275,7 @@ class _DynamicLSTMAscend(Module):
 
     def __init__(self):
         super().__init__()
-        self.lstm = DynamicRNN()
+        self.lstm = P.DynamicRNN()
 
     def forward(self, x, h_0, seq_length, w_ih, w_hh, b_ih, b_hh):
         '''Dynamic LSTM module on Ascend'''
@@ -324,7 +327,7 @@ def __init__(self, mode, input_size, hidden_size, num_layers=1, bias=True,
                              "recurrent layer, so non-zero dropout expects "
                              "num_layers greater than 1, but got dropout={} and "
                              "num_layers={}".format(dropout, num_layers))
-
+        is_ascend = context.get_context("device_target") == "Ascend"
         if mode == "LSTM":
             gate_size = 4 * hidden_size
             self.rnn = _DynamicLSTMAscend() if is_ascend else _DynamicLSTMCPUGPU()
@@ -344,8 +347,8 @@ def __init__(self, mode, input_size, hidden_size, num_layers=1, bias=True,
             raise ValueError(f"For '{self.cls_name}', the 'mode' must be in ['RNN_RELU', 'RNN_TANH', 'LSTM', 'GRU'], "
                              f"but got {mode}.")
 
-        self.reverse = ReverseV2([0])
-        self.reverse_sequence = ReverseSequence(0, 1)
+        self.reverse = P.ReverseV2([0])
+        self.reverse_sequence = P.ReverseSequence(0, 1)
         self.hidden_size = hidden_size
         self.batch_first = batch_first
         self.num_layers = num_layers
diff --git a/mindnlp/core/ops/creation.py b/mindnlp/core/ops/creation.py
index 502a907a7..b87e66e0b 100644
--- a/mindnlp/core/ops/creation.py
+++ b/mindnlp/core/ops/creation.py
@@ -6,10 +6,13 @@
     from mindspore._c_expression import Tensor as CTensor # pylint: disable=no-name-in-module, import-error
 except:
     from mindspore._c_expression import TensorPy as CTensor # pylint: disable=no-name-in-module, import-error
+
+from mindspore._c_expression.typing import Type
 from mindspore import ops
 from mindspore.ops._primitive_cache import _get_cache_prim
 from ..configs import use_pyboost, ON_ORANGE_PI
 from .._bind import get_default_dtype, get_default_device
+from .utils import py2dtype
 
 def as_strided(self, size, stride, storage_offset=None):
     if len(size) != len(stride):
@@ -56,7 +59,7 @@ def zeros(*size, dtype=None, device=None, requires_grad=False, **kwargs):
 
 # zeros_like
 has_zeros_like = hasattr(mindspore.mint, 'zeros_like')
-def zeros_like(input, *, dtype=None, memory_format=None):
+def zeros_like(input, *, dtype=None, memory_format=None, **kwargs):
     if dtype is None:
         dtype = input.dtype
     if use_pyboost() and has_zeros_like:
@@ -71,8 +74,8 @@ def ones(*size, dtype=None, device=None):
         size = size[0]
     if dtype is None:
         dtype = get_default_dtype()
-    if dtype == bool:
-        dtype = mindspore.bool_
+    if not isinstance(dtype, Type):
+        dtype = py2dtype[dtype]
     if use_pyboost() and has_ones:
         return mindspore.mint.ones(size, dtype=dtype)
     return _ones(size, dtype)
diff --git a/mindnlp/core/ops/inplace.py b/mindnlp/core/ops/inplace.py
index d67044417..0dc00ca60 100644
--- a/mindnlp/core/ops/inplace.py
+++ b/mindnlp/core/ops/inplace.py
@@ -5,6 +5,8 @@
 from mindspore.ops.auto_generate.gen_ops_prim import inplace_normal_op
 
 from mindnlp import core
+from ..configs import use_pyboost
+from ._inner import assign
 
 generator_step_ = 12
 
@@ -97,6 +99,25 @@ def inplace_index_add(input, dim, index, source):
     _inplace = _get_cache_prim(ops.InplaceIndexAdd)(dim)
     return _inplace(input, index, source)
 
+has_squeeze = hasattr(mindspore.mint, 'squeeze')
+def inplace_squeeze(input, *dim, **kwargs):
+    dim = kwargs.get('dim', dim)
+    if use_pyboost() and has_squeeze:
+        out = mindspore.mint.squeeze(input, dim)
+    else:
+        out = ops.squeeze(input, dim)
+    input.assign_value(out)
+    return input
+
+
+has_unsqueeze = hasattr(mindspore.mint, 'unsqueeze')
+def inplace_unsqueeze(input, dim=None):
+    if use_pyboost() and has_unsqueeze:
+        out = mindspore.mint.unsqueeze(input, dim)
+    else:
+        out = ops.expand_dims(input, dim)
+    input.assign_value(out)
+    return input
+
 __all__ = [
     'inplace_copy',
     'inplace_zero',
@@ -106,5 +127,7 @@ def inplace_index_add(input, dim, index, source):
     'inplace_add',
     'inplace_scatter',
     'inplace_index_copy',
-    'inplace_index_add'
+    'inplace_index_add',
+    'inplace_squeeze',
+    'inplace_unsqueeze'
 ]
diff --git a/mindnlp/core/ops/other.py b/mindnlp/core/ops/other.py
index ee760fc6b..7c02b9d44 100644
--- a/mindnlp/core/ops/other.py
+++ b/mindnlp/core/ops/other.py
@@ -390,6 +390,8 @@ def einsum(equation, *operands):
         AssertionError: If more operands are provided than specified in the equation.
         RuntimeError: If operands do not broadcast with remapped shapes [original->remapped].
""" + if isinstance(operands[0], (tuple, list)): + operands = operands[0] if use_pyboost() and has_einsum: return mindspore.mint.einsum(equation, *operands) assert operands, "einsum(): must provide at least one operand" diff --git a/mindnlp/core/ops/reduction.py b/mindnlp/core/ops/reduction.py index 4f8c3c629..808d62b6f 100644 --- a/mindnlp/core/ops/reduction.py +++ b/mindnlp/core/ops/reduction.py @@ -8,6 +8,8 @@ from ._inner import call_ms_func max_out = namedtuple('max_out', ['values', 'indices']) +min_out = namedtuple('min_out', ['values', 'indices']) + # argmax has_argmax = hasattr(mindspore.mint, 'argmax') def argmax(input, dim=None, keepdim=False): @@ -77,8 +79,10 @@ def max(*args, **kwargs): # min has_min = hasattr(mindspore.mint, 'min') def min(*args, **kwargs): - return mindspore.mint.min(*args, **kwargs) - + out = mindspore.mint.min(*args, **kwargs) + if isinstance(out, tuple): + return min_out(values=out[0], indices=out[1]) + return out # dist diff --git a/mindnlp/core/ops/utils.py b/mindnlp/core/ops/utils.py index a176e28e1..72bcd6b26 100644 --- a/mindnlp/core/ops/utils.py +++ b/mindnlp/core/ops/utils.py @@ -1,3 +1,5 @@ +from .. import _dtype + def sum_to(x, shape): """Sum elements along axes to output an array of a given shape. @@ -19,3 +21,9 @@ def sum_to(x, shape): if lead > 0: y = y.squeeze(lead_axis) return y + +py2dtype = { + bool: _dtype.bool, + int: _dtype.int64, + float: _dtype.float32, +} diff --git a/tests/run_test.py b/tests/run_test.py index 3beb5792a..db28ec331 100644 --- a/tests/run_test.py +++ b/tests/run_test.py @@ -26,8 +26,10 @@ def run_tests(): "and not with_static_cache " \ "and not compile " \ "and not compilation " \ - "and not torchscript " + "and not torchscript " \ + "and not torch_fx" + pytest_args.extend(["--ignore-glob=test_modeling_flax_*.py"]) pytest_args.extend(['-k', skip_ut]) if not pytest_args: print("未提供参数,默认运行当前目录下所有测试")