Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion mindnlp/core/_dynamo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,7 @@
# substitute_in_graph,
)

from . import eval_frame
from . import eval_frame

def reset():
pass
13 changes: 13 additions & 0 deletions mindnlp/core/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class StubTensor: pass
mindspore.int16: 2,
mindspore.bfloat16: 2,
mindspore.float16: 2,
mindspore.bool_: 1
}

DEVICE_MAP = {
Expand Down Expand Up @@ -450,6 +451,18 @@ def clamp_min(self, value):
Tensor.index_copy_ = ops.inplace_index_copy
StubTensor.index_copy_ = ops.inplace_index_copy

Tensor.max = ops.max
StubTensor.max = ops.max

Tensor.min = ops.min
StubTensor.min = ops.min

Tensor.squeeze_ = ops.inplace_squeeze
StubTensor.squeeze_ = ops.inplace_squeeze

Tensor.unsqueeze_ = ops.inplace_unsqueeze
StubTensor.unsqueeze_ = ops.inplace_unsqueeze


def _rebuild_from_type_v2(func, new_type, args, state):
ret = func(*args)
Expand Down
5 changes: 5 additions & 0 deletions mindnlp/core/nn/attention/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
class SDPBackend:
pass

def sdpa_kernel(*args, **kwargs):
pass
10 changes: 8 additions & 2 deletions mindnlp/core/nn/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def pad(input, pad, mode='constant', value=0.0):
new_pad = ()
for idx, pad_v in enumerate(pad):
if pad_v < 0:
dim = idx // 2
dim = input.ndim - 1 - idx // 2
input = input.narrow(dim, 0, input.shape[dim] + pad_v)
pad_v = 0
new_pad += (pad_v,)
Expand Down Expand Up @@ -530,7 +530,7 @@ def conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
pad_mode = padding
pad = (0,) * 4

_conv2d = _get_cache_prim(ops.Conv2D)(out_channel=weight.shape[0] * groups,
_conv2d = _get_cache_prim(ops.Conv2D)(out_channel=weight.shape[0],
kernel_size=(1, weight.shape[-1]),
mode=1,
pad_mode=pad_mode,
Expand Down Expand Up @@ -593,6 +593,12 @@ def conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
"""
raise ValueError("Requires mindspore >= 2.3.0 by default, or set into pyboost mode by calling torch.config.set_byboost(True).")

def conv_transpose1d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
return mint.nn.functional.conv_transpose1d(input, weight, bias, stride, padding, output_padding, groups, dilation)


def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
return mint.nn.functional.conv_transpose2d(input, weight, bias, stride, padding, output_padding, groups, dilation)

def max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False):
if use_pyboost():
Expand Down
211 changes: 117 additions & 94 deletions mindnlp/core/nn/modules/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,11 +322,11 @@ def forward(self, input: Tensor) -> Tensor:
class _ConvTransposeNd(_ConvNd):
def __init__(self, in_channels, out_channels, kernel_size, stride,
padding, dilation, transposed, output_padding,
groups, bias, padding_mode, dtype=None) -> None:
groups, bias, padding_mode, dtype=None, device=None) -> None:
if padding_mode != 'zeros':
raise ValueError(f'Only "zeros" padding mode is supported for {self.__class__.__name__}')

factory_kwargs = {'dtype': dtype}
factory_kwargs = {'dtype': dtype, 'device': device}
super().__init__(
in_channels, out_channels, kernel_size, stride,
padding, dilation, transposed, output_padding,
Expand Down Expand Up @@ -426,62 +426,71 @@ class ConvTranspose1d(_ConvTransposeNd):
bias (Tensor): the learnable bias of the module of shape (out_channels)
"""

def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, output_padding=0, groups=1, bias=True, dilation=1, padding_mode: str = 'zeros'):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: _size_1_t,
stride: _size_1_t = 1,
padding: _size_1_t = 0,
output_padding: _size_1_t = 0,
groups: int = 1,
bias: bool = True,
dilation: _size_1_t = 1,
padding_mode: str = "zeros",
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
kernel_size = _single(kernel_size)
stride = _single(stride)
padding = _single(padding)
dilation = _single(dilation)
output_padding = _single(output_padding)
super(ConvTranspose1d, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
True, output_padding, groups, bias, padding_mode)

pad_mode = 'pad'
pad = padding
if isinstance(padding, tuple):
pad = (0, 0, padding[0], padding[0])
elif isinstance(padding, int):
pad = (0, 0) + (padding,) * 2
if not isinstance(padding, (int, tuple)):
pad_mode = padding
pad = (0,) * 4

# cause Conv2DTranspose's out_channel refers to Conv2D's out_channel.
self.conv2d_transpose = mops.Conv2DTranspose(out_channel=self.out_channels,
kernel_size=(1,) + self.kernel_size,
mode=1,
pad_mode=pad_mode,
pad=pad,
stride=(1,) + self.stride,
dilation=(1,) + self.dilation,
group=self.groups)
self.h_add = _deconv_output_length(pad_mode, 1, 1, 1, pad[0] + pad[1])
self.w_add = _deconv_output_length(pad_mode, kernel_size[0], stride[0], dilation[0], pad[2] + pad[3])

def forward(self, input, output_size=None):
if self.padding_mode != 'zeros':
raise ValueError('Only `zeros` padding mode is supported for ConvTranspose2d')
super().__init__(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
True,
output_padding,
groups,
bias,
padding_mode,
**factory_kwargs,
)

def forward(self, input: Tensor, output_size: Optional[list[int]] = None) -> Tensor:
if self.padding_mode != "zeros":
raise ValueError(
"Only `zeros` padding mode is supported for ConvTranspose1d"
)

assert isinstance(self.padding, tuple)
# One cannot replace List by Tuple or Sequence in "_output_padding" because
# TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
num_spatial_dims = 1
output_padding = self._output_padding(
input, output_size, self.stride, self.padding, self.kernel_size, # type: ignore[arg-type]
num_spatial_dims, self.dilation) # type: ignore[arg-type]
input = mops.expand_dims(input, 2)
n, _, h, w = input.shape
conv2d_trans_ret = self.conv2d_transpose(input, self.weight.expand_dims(2),
(n, self.out_channels,
h + self.h_add,
w * self.stride[0] + self.w_add))
if self.bias is not None:
conv2d_trans_ret = mops.bias_add(conv2d_trans_ret, self.bias)

conv2d_trans_ret = conv2d_trans_ret.squeeze(2)
conv2d_trans_ret = ops.pad(conv2d_trans_ret, (0,) + output_padding, value=0.)
return conv2d_trans_ret
input,
output_size,
self.stride, # type: ignore[arg-type]
self.padding, # type: ignore[arg-type]
self.kernel_size, # type: ignore[arg-type]
num_spatial_dims,
self.dilation, # type: ignore[arg-type]
)
return F.conv_transpose1d(
input,
self.weight,
self.bias,
self.stride,
self.padding,
output_padding,
self.groups,
self.dilation,
)


def _deconv_output_length(pad_mode, filter_size, stride_size, dilation_size, padding):
Expand Down Expand Up @@ -582,66 +591,80 @@ class ConvTranspose2d(_ConvTransposeNd):
https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
"""

def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, output_padding=0, groups=1, bias=True, dilation=1,
padding_mode='zeros', dtype=None):
factory_kwargs = {'dtype': dtype}
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: _size_2_t,
stride: _size_2_t = 1,
padding: _size_2_t = 0,
output_padding: _size_2_t = 0,
groups: int = 1,
bias: bool = True,
dilation: _size_2_t = 1,
padding_mode: str = "zeros",
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
kernel_size = _pair(kernel_size)
stride = _pair(stride)
padding = _pair(padding)
dilation = _pair(dilation)
output_padding = _pair(output_padding)
super().__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
True, output_padding, groups, bias, padding_mode, **factory_kwargs)

pad_mode = 'pad'
pad = padding
if isinstance(padding, tuple):
pad = (padding[0], padding[0], padding[1], padding[1])
elif isinstance(padding, int):
pad = (padding,) * 4
if not isinstance(padding, (int, tuple)):
pad_mode = padding
pad = (0,) * 4

# cause Conv2DTranspose's out_channel refers to Conv2D's out_channel.
self.conv2d_transpose = mops.Conv2DTranspose(out_channel=in_channels,
kernel_size=kernel_size,
mode=1,
pad_mode=pad_mode,
pad=pad,
stride=stride,
dilation=dilation,
group=groups)

self.h_add = _deconv_output_length(pad_mode, kernel_size[0], stride[0], dilation[0], pad[0] + pad[1])
self.w_add = _deconv_output_length(pad_mode, kernel_size[1], stride[1], dilation[1], pad[2] + pad[3])

def forward(self, input, output_size=None):
if self.padding_mode != 'zeros':
raise ValueError('Only `zeros` padding mode is supported for ConvTranspose2d')
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
True,
output_padding,
groups,
bias,
padding_mode,
**factory_kwargs,
)

def forward(self, input: Tensor, output_size: Optional[list[int]] = None) -> Tensor:
"""
Performs the forward pass.

Attributes:
input (Tensor): The input tensor.
output_size (list[int], optional): A list of integers representing
the size of the output tensor. Default is None.
"""
if self.padding_mode != "zeros":
raise ValueError(
"Only `zeros` padding mode is supported for ConvTranspose2d"
)

assert isinstance(self.padding, tuple)
# One cannot replace List by Tuple or Sequence in "_output_padding" because
# TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
num_spatial_dims = 2
output_padding = self._output_padding(
input, output_size, self.stride, self.padding, self.kernel_size, # type: ignore[arg-type]
num_spatial_dims, self.dilation) # type: ignore[arg-type]

n, _, h, w = input.shape
conv2d_trans_ret = self.conv2d_transpose(input, self.weight,
(n, self.out_channels,
h * self.stride[0] + self.h_add,
w * self.stride[1] + self.w_add))
if self.bias is not None:
conv2d_trans_ret = mops.bias_add(conv2d_trans_ret, self.bias)

conv2d_trans_ret = ops.pad(conv2d_trans_ret, output_padding, value=0.)

return conv2d_trans_ret
input,
output_size,
self.stride, # type: ignore[arg-type]
self.padding, # type: ignore[arg-type]
self.kernel_size, # type: ignore[arg-type]
num_spatial_dims,
self.dilation, # type: ignore[arg-type]
)

return F.conv_transpose2d(
input,
self.weight,
self.bias,
self.stride,
self.padding,
output_padding,
self.groups,
self.dilation,
)

# class ConvTranspose3d(_ConvTransposeNd):
# r"""Applies a 3D transposed convolution operator over an input image composed of several input
Expand Down
15 changes: 9 additions & 6 deletions mindnlp/core/nn/modules/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,15 @@
"""RNN operators module, include RNN, GRU."""
import math
import warnings
from mindspore import context
from mindspore import ops as P

from ..parameter import Parameter
from mindnlp import core
from .module import Module
from .dropout import Dropout
from ... import ops
from ..parameter import Parameter
from .. import init
from ... import ops


__all__ = ['LSTM', 'GRU', 'RNN']
Expand Down Expand Up @@ -272,7 +275,7 @@ class _DynamicLSTMAscend(Module):

def __init__(self):
super().__init__()
self.lstm = DynamicRNN()
self.lstm = P.DynamicRNN()

def forward(self, x, h_0, seq_length, w_ih, w_hh, b_ih, b_hh):
'''Dynamic LSTM module on Ascend'''
Expand Down Expand Up @@ -324,7 +327,7 @@ def __init__(self, mode, input_size, hidden_size, num_layers=1, bias=True,
"recurrent layer, so non-zero dropout expects "
"num_layers greater than 1, but got dropout={} and "
"num_layers={}".format(dropout, num_layers))

is_ascend = context.get_context("device_target") == "Ascend"
if mode == "LSTM":
gate_size = 4 * hidden_size
self.rnn = _DynamicLSTMAscend() if is_ascend else _DynamicLSTMCPUGPU()
Expand All @@ -344,8 +347,8 @@ def __init__(self, mode, input_size, hidden_size, num_layers=1, bias=True,
raise ValueError(f"For '{self.cls_name}', the 'mode' must be in ['RNN_RELU', 'RNN_TANH', 'LSTM', 'GRU'], "
f"but got {mode}.")

self.reverse = ReverseV2([0])
self.reverse_sequence = ReverseSequence(0, 1)
self.reverse = P.ReverseV2([0])
self.reverse_sequence = P.ReverseSequence(0, 1)
self.hidden_size = hidden_size
self.batch_first = batch_first
self.num_layers = num_layers
Expand Down
Loading
Loading