3 changes: 2 additions & 1 deletion .gitignore
@@ -166,4 +166,5 @@ xiyouji.txt
*.jit
flagged/

huggingface_transformers/
huggingface_transformers/
diffusers/
2 changes: 2 additions & 0 deletions mindnlp/core/__init__.py
@@ -26,6 +26,8 @@
Union as _Union,
)

from mindspore.runtime import Stream

strided = None
contiguous_format = None
preserve_format = None
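The new import re-exports MindSpore's runtime stream type under the core namespace. A minimal sketch of the intended use, assuming mindspore.runtime.Stream exposes the usual construction-plus-synchronize surface and a stream-capable backend is active:

from mindnlp import core

stream = core.Stream()   # now resolves to mindspore.runtime.Stream
stream.synchronize()     # block until work queued on this stream completes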
31 changes: 30 additions & 1 deletion mindnlp/core/_tensor.py
@@ -538,11 +538,34 @@ def new_full(self, size, fill_value, *, dtype=None, device=None, requires_grad=F
StubTensor.new_full = new_full

def new_zeros(self, *size, dtype=None, device=None, requires_grad=False, layout=None, pin_memory=False):
return ops.zeros(*size, dtype=dtype if dtype is not None else self.dtype)
if isinstance(size[0], (tuple, list)):
size = size[0]

new_size = ()
for s in size:
if isinstance(s, Tensor):
s = s.item()
new_size += (s,)
return ops.zeros(*new_size, dtype=dtype if dtype is not None else self.dtype)

Tensor.new_zeros = new_zeros
StubTensor.new_zeros = new_zeros

def new_ones(self, *size, dtype=None, device=None, requires_grad=False, layout=None, pin_memory=False, **kwargs):
size = kwargs.get('size', size)
if isinstance(size[0], (tuple, list)):
size = size[0]

new_size = ()
for s in size:
if isinstance(s, Tensor):
s = s.item()
new_size += (s,)
return ops.ones(*new_size, dtype=dtype if dtype is not None else self.dtype)

Tensor.new_ones = new_ones
StubTensor.new_ones = new_ones

Tensor.sum = ops.sum
StubTensor.sum = ops.sum

@@ -598,6 +621,12 @@ def __contains__(self, item):
Tensor.mean = ops.mean
StubTensor.mean = ops.mean

Tensor.amax = ops.amax
StubTensor.amax = ops.amax

Tensor.as_strided = ops.as_strided
StubTensor.as_strided = ops.as_strided

def _rebuild_from_type_v2(func, new_type, args, state):
ret = func(*args)
return ret
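The patched new_zeros/new_ones normalize the size arguments before dispatch: a packed tuple or list is unwrapped, and any 0-d Tensor dimension is converted with .item(). A hedged sketch of the call patterns this enables, assuming importing mindnlp.core applies these Tensor patches:

import mindspore
import mindnlp.core  # importing the compat layer patches mindspore.Tensor

x = mindspore.ops.ones((2, 3), mindspore.float32)
a = x.new_zeros(4, 5)                   # varargs of ints
b = x.new_zeros((4, 5))                 # packed tuple or list
c = x.new_ones(mindspore.Tensor(4), 5)  # 0-d Tensor dim, unwrapped via .item()
assert a.shape == b.shape == c.shape == (4, 5)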
13 changes: 11 additions & 2 deletions mindnlp/core/fft/__init__.py
@@ -2,7 +2,7 @@
from mindspore import ops
from mindspore.ops._primitive_cache import _get_cache_prim
from ..configs import use_pyboost
from ..ops import narrow
from ..ops import narrow, roll
from ..nn import functional as F

def rfft(input, n=None, dim=-1, norm="backward"):
@@ -35,4 +35,13 @@ def fftn(input, s=None, dim=None, norm=None):
def fft(input, s=None, dim=-1, norm=None):
return ops.fft(input, s, dim, norm)

__all__ = ['fft', 'fftn', 'irfft', 'rfft']
def fftshift(x, dim=None):
return ops.fftshift(x, dim)

def ifftn(input, s=None, dim=None, norm=None, *, out=None):
return ops.ifftn(input, s, dim, norm)

def ifftshift(input, dim=None):
return ops.ifftshift(input, dim)

__all__ = ['fft', 'fftn', 'fftshift', 'ifftn', 'ifftshift', 'irfft', 'rfft']
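The new wrappers are thin delegations to mindspore.ops. A hedged round-trip sketch, assuming the usual FFT semantics of those primitives:

import mindspore
from mindnlp.core import fft

x = mindspore.ops.randn(4, 4)
spec = fft.fftn(x)                             # forward n-D FFT
centered = fft.fftshift(spec)                  # move the zero-frequency bin to the center
restored = fft.ifftn(fft.ifftshift(centered))  # undo both; restored ≈ x (complex)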
11 changes: 11 additions & 0 deletions mindnlp/core/mps/__init__.py
@@ -0,0 +1,11 @@
def is_available():
return False

def empty_cache():
pass

def device_count():
return 0

def manual_seed(*args, **kwargs):
pass
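These stubs mirror the torch.mps surface so ported code that probes for Apple-silicon support falls through cleanly to CPU. The pattern they exist to satisfy:

from mindnlp.core import mps

device = "mps" if mps.is_available() else "cpu"  # always "cpu" here
mps.manual_seed(0)   # accepted but a no-op, for API compatibility
mps.empty_cache()    # likewise a no-op
print(device, mps.device_count())                # -> cpu 0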
2 changes: 1 addition & 1 deletion mindnlp/core/nn/modules/__init__.py
@@ -6,7 +6,7 @@
from .normalization import LayerNorm, GroupNorm, RMSNorm
from .dropout import Dropout, Dropout2d
from .activation import *
from .conv import Conv3d, Conv2d, Conv1d, ConvTranspose2d, ConvTranspose1d
from .conv import Conv3d, Conv2d, Conv1d, ConvTranspose2d, ConvTranspose1d, ConvTranspose3d
from .padding import ZeroPad2d, ConstantPad2d, ConstantPad1d, ConstantPad3d
from .batchnorm import BatchNorm2d, BatchNorm1d, SyncBatchNorm
from .pooling import AdaptiveAvgPool2d, AvgPool1d, MaxPool2d, MaxPool1d, AdaptiveAvgPool1d, AvgPool2d
190 changes: 95 additions & 95 deletions mindnlp/core/nn/modules/conv.py
@@ -666,101 +666,101 @@ def forward(self, input: Tensor, output_size: Optional[list[int]] = None) -> Ten
self.dilation,
)

# class ConvTranspose3d(_ConvTransposeNd):
# r"""Applies a 3D transposed convolution operator over an input image composed of several input
# planes.
# The transposed convolution operator multiplies each input value element-wise by a learnable kernel,
# and sums over the outputs from all input feature planes.

# This module can be seen as the gradient of Conv3d with respect to its input.
# It is also known as a fractionally-strided convolution or
# a deconvolution (although it is not an actual deconvolution operation).

# | :attr:`stride` controls the stride for the cross-correlation.
# | If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
# for :attr:`padding` number of points.
# | If :attr:`output_padding` is non-zero, then the output is implicitly zero-padded on one side
# for :attr:`output_padding` number of points.
# | :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm.
# It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.
# | :attr:`groups` controls the connections between inputs and outputs. `in_channels` and `out_channels`
# must both be divisible by `groups`.
# | At groups=1, all inputs are convolved to all outputs.
# | At groups=2, the operation becomes equivalent to having two conv layers
# side by side, each seeing half the input channels,
# and producing half the output channels, and both subsequently concatenated.
# At groups=`in_channels`, each input channel is convolved with its own set of filters
# (of size `out_channels // in_channels`).

# The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding`
# can either be:

# - a single ``int`` -- in which case the same value is used for the depth, height and width dimensions
# - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension,
# the second `int` for the height dimension and the third `int` for the width dimension

# .. note::

# Depending of the size of your kernel, several (of the last)
# columns of the input might be lost, because it is a valid `cross-correlation`_,
# and not a full `cross-correlation`_.
# It is up to the user to add proper padding.

# Args:
# in_channels (int): Number of channels in the input image
# out_channels (int): Number of channels produced by the convolution
# kernel_size (int or tuple): Size of the convolving kernel
# stride (int or tuple, optional): Stride of the convolution
# padding (int or tuple, optional): Zero-padding added to both sides of the input
# output_padding (int or tuple, optional): Zero-padding added to one side of the output
# groups (int, optional): Number of blocked connections from input channels to output channels
# bias (bool, optional): If True, adds a learnable bias to the output
# dilation (int or tuple, optional): Spacing between kernel elements

# Shape:
# - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`
# - Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` where
# :math:`D_{out} = (D_{in} - 1) * stride[0] - 2 * padding[0] + kernel\_size[0] + output\_padding[0]`
# :math:`H_{out} = (H_{in} - 1) * stride[1] - 2 * padding[1] + kernel\_size[1] + output\_padding[1]`
# :math:`W_{out} = (W_{in} - 1) * stride[2] - 2 * padding[2] + kernel\_size[2] + output\_padding[2]`

# Attributes:
# weight (Tensor): the learnable weights of the module of shape
# (in_channels, out_channels, kernel_size[0], kernel_size[1], kernel_size[2])
# bias (Tensor): the learnable bias of the module of shape (out_channels)

# Examples::

# >>> # With square kernels and equal stride
# >>> m = nn.ConvTranspose3d(16, 33, 3, stride=2)
# >>> # non-square kernels and unequal stride and with padding
# >>> m = nn.Conv3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(0, 4, 2))
# >>> input = autograd.Variable(core.randn(20, 16, 10, 50, 100))
# >>> output = m(input)

# .. _cross-correlation:
# https://en.wikipedia.org/wiki/Cross-correlation

# .. _link:
# https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
# """

# def __init__(self, in_channels, out_channels, kernel_size, stride=1,
# padding=0, output_padding=0, groups=1, bias=True, dilation=1):
# kernel_size = _triple(kernel_size)
# stride = _triple(stride)
# padding = _triple(padding)
# dilation = _triple(dilation)
# output_padding = _triple(output_padding)
# super(ConvTranspose3d, self).__init__(
# in_channels, out_channels, kernel_size, stride, padding, dilation,
# True, output_padding, groups, bias)

# def forward(self, input, output_size=None):
# output_padding = self._output_padding(input, output_size)
# return F.conv_transpose3d(
# input, self.weight, self.bias, self.stride, self.padding,
# output_padding, self.groups, self.dilation)
class ConvTranspose3d(_ConvTransposeNd):
r"""Applies a 3D transposed convolution operator over an input image composed of several input
planes.
The transposed convolution operator multiplies each input value element-wise by a learnable kernel,
and sums over the outputs from all input feature planes.

This module can be seen as the gradient of Conv3d with respect to its input.
It is also known as a fractionally-strided convolution or
a deconvolution (although it is not an actual deconvolution operation).

| :attr:`stride` controls the stride for the cross-correlation.
| If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
for :attr:`padding` number of points.
| If :attr:`output_padding` is non-zero, then the output is implicitly zero-padded on one side
for :attr:`output_padding` number of points.
| :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm.
It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.
| :attr:`groups` controls the connections between inputs and outputs. `in_channels` and `out_channels`
must both be divisible by `groups`.
| At groups=1, all inputs are convolved to all outputs.
| At groups=2, the operation becomes equivalent to having two conv layers
side by side, each seeing half the input channels,
and producing half the output channels, and both subsequently concatenated.
At groups=`in_channels`, each input channel is convolved with its own set of filters
(of size `out_channels // in_channels`).

The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding`
can either be:

- a single ``int`` -- in which case the same value is used for the depth, height and width dimensions
- a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension,
the second `int` for the height dimension and the third `int` for the width dimension

.. note::

Depending on the size of your kernel, several (of the last)
columns of the input might be lost, because it is a valid `cross-correlation`_,
and not a full `cross-correlation`_.
It is up to the user to add proper padding.

Args:
in_channels (int): Number of channels in the input image
out_channels (int): Number of channels produced by the convolution
kernel_size (int or tuple): Size of the convolving kernel
stride (int or tuple, optional): Stride of the convolution
padding (int or tuple, optional): Zero-padding added to both sides of the input
output_padding (int or tuple, optional): Zero-padding added to one side of the output
groups (int, optional): Number of blocked connections from input channels to output channels
bias (bool, optional): If True, adds a learnable bias to the output
dilation (int or tuple, optional): Spacing between kernel elements

Shape:
- Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`
- Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` where
:math:`D_{out} = (D_{in} - 1) * stride[0] - 2 * padding[0] + kernel\_size[0] + output\_padding[0]`
:math:`H_{out} = (H_{in} - 1) * stride[1] - 2 * padding[1] + kernel\_size[1] + output\_padding[1]`
:math:`W_{out} = (W_{in} - 1) * stride[2] - 2 * padding[2] + kernel\_size[2] + output\_padding[2]`

Attributes:
weight (Tensor): the learnable weights of the module of shape
(in_channels, out_channels, kernel_size[0], kernel_size[1], kernel_size[2])
bias (Tensor): the learnable bias of the module of shape (out_channels)

Examples::

>>> # With square kernels and equal stride
>>> m = nn.ConvTranspose3d(16, 33, 3, stride=2)
>>> # non-square kernels and unequal stride and with padding
>>> m = nn.ConvTranspose3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(0, 4, 2))
>>> input = core.randn(20, 16, 10, 50, 100)
>>> output = m(input)

.. _cross-correlation:
https://en.wikipedia.org/wiki/Cross-correlation

.. _link:
https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
"""

def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, output_padding=0, groups=1, bias=True, dilation=1):
kernel_size = _triple(kernel_size)
stride = _triple(stride)
padding = _triple(padding)
dilation = _triple(dilation)
output_padding = _triple(output_padding)
super(ConvTranspose3d, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
True, output_padding, groups, bias)

def forward(self, input, output_size=None):
output_padding = self._output_padding(input, output_size)
return F.conv_transpose3d(
input, self.weight, self.bias, self.stride, self.padding,
output_padding, self.groups, self.dilation)


# TODO: Conv2dLocal
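A hedged smoke test for the re-enabled module, checking the output-shape formula from the docstring (D_out = (D_in - 1) * stride - 2 * padding + kernel_size + output_padding); it assumes the layer is callable on a MindSpore tensor like the surrounding Conv classes:

import mindspore
from mindnlp.core import nn

m = nn.ConvTranspose3d(16, 33, 3, stride=2)
x = mindspore.ops.randn(20, 16, 10, 50, 100)
y = m(x)
assert y.shape == (20, 33, 21, 101, 201)  # e.g. (10 - 1) * 2 + 3 = 21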
1 change: 1 addition & 0 deletions mindnlp/core/ops/array.py
@@ -120,6 +120,7 @@ def masked_select(input, mask, *, out=None):
# narrow
has_narrow = hasattr(mindspore.mint, 'narrow')
def narrow(input, dim, start, length):
length = length.item() if isinstance(length, mindspore.Tensor) else length
if use_pyboost() and has_narrow:
return mindspore.mint.narrow(input, dim, start, length)
return ops.narrow(input, dim, start, length)
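With the fix, a 0-d Tensor length is unwrapped with .item() before dispatching to mint.narrow or ops.narrow. A hedged sketch, assuming narrow is importable from the compat ops namespace as the fft module's `from ..ops import narrow` suggests:

import mindspore
from mindnlp.core.ops import narrow

x = mindspore.ops.arange(10)
a = narrow(x, 0, 2, 5)                    # plain int length
b = narrow(x, 0, 2, mindspore.Tensor(5))  # Tensor length now accepted too
assert a.shape == b.shape == (5,)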
2 changes: 1 addition & 1 deletion mindnlp/core/ops/creation.py
@@ -191,7 +191,7 @@ def empty_like(input, *, dtype=None, layout=None, device=None, requires_grad=Fal

# full
has_full = hasattr(mindspore.mint, 'full')
def full(size, fill_value, *, dtype=None, device=None):
def full(size, fill_value, *, dtype=None, device=None, **kwargs):
new_size = ()
for s in size:
if isinstance(s, mindspore.Tensor):
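The widened signature swallows torch-only keywords so ported call sites don't raise TypeError; the loop shown above already unwraps Tensor dimensions. A hedged sketch (it assumes the extra keywords are simply ignored, which the truncated body suggests but does not show):

import mindspore
from mindnlp.core.ops import full

t = full((2, 3), 7, dtype=mindspore.int32, requires_grad=False)  # extra kw absorbed
u = full((mindspore.Tensor(2), 3), 7)                            # Tensor dim unwrapped
assert t.shape == u.shape == (2, 3)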
4 changes: 2 additions & 2 deletions mindnlp/core/ops/other.py
@@ -152,8 +152,8 @@ def clone(input):
# cumsum
has_cumsum = hasattr(mindspore.mint, "cumsum")


def cumsum(input, dim, dtype=None, out=None):
def cumsum(input, dim=None, dtype=None, out=None, **kwargs):
dim = kwargs.pop('axis', dim)
input_dtype = input.dtype
if input_dtype == mindspore.int64:
input = input.to(mindspore.int32)
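With the alias in place, numpy-style call sites that pass axis= keep working, and int64 input is routed through int32 as the body shows. A hedged sketch:

import mindspore
from mindnlp.core.ops import cumsum

x = mindspore.Tensor([[1, 2], [3, 4]], mindspore.int64)
a = cumsum(x, dim=1)
b = cumsum(x, axis=1)  # new numpy-style alias, mapped onto dim
assert a.shape == b.shape == (2, 2)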