46 changes: 44 additions & 2 deletions mindnlp/core/_tensor.py
@@ -6,7 +6,7 @@
from mindspore.common.tensor import _TensorMeta
from mindspore._c_expression.typing import Type
try:
from mindspore.common._stub_tensor import StubTensor
from mindspore.common._stub_tensor import StubTensor, _stub_method
except:
class StubTensor: pass

@@ -17,7 +17,7 @@ class StubTensor: pass

from . import ops, _dtype
from ._dtype import dtype2np
from ._bind import get_default_device, device_
from ._bind import get_default_device, device_, get_default_dtype
from .configs import use_pyboost, ON_A1
from .storage import UntypedStorage
from ._utils import _rebuild_tensor_v2
@@ -98,6 +98,16 @@ def is_tensor(x):
return isinstance(x, Tensor)

def enable_mindspore_patch():
old_init = Tensor.__init__
def __init__(self, *args, **kwargs):
if len(args) > 1 and all(isinstance(arg, int) for arg in args):
tensor = Tensor_(shape=args, dtype=get_default_dtype())
old_init(self, tensor, internal=True)
else:
old_init(self, *args, **kwargs)

Tensor.__init__ = __init__

def __reduce_ex__(self, protocol):
if isinstance(self, StubTensor):
data = Tensor_(self.stub_sync())
@@ -280,6 +290,8 @@ def __setitem__(self, slices, value):
# s = list(s)
# new_slices += (s,)
# slices = new_slices
if not isinstance(value, Tensor):
value = tensor(value, dtype=self.dtype)
return origin_setitem(self, slices, value)

Tensor.__setitem__ = __setitem__
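
After this change, the right-hand side of an item assignment no longer has to be a Tensor; scalars and array-likes are wrapped with the target's dtype first. A hedged sketch, assuming the patch is active and core.ones mirrors torch.ones:

from mindnlp import core

t = core.ones(3, 4)
t[0] = 0.5        # plain Python scalar: wrapped as tensor(0.5, dtype=t.dtype)
t[1, 2] = 3       # ints are converted the same way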
@@ -469,6 +481,36 @@ def pin_memory(self, *args, **kwargs):
Tensor.pin_memory = pin_memory
StubTensor.pin_memory = pin_memory

def __deepcopy__(self, memodict):
new_obj = Tensor(self)
return new_obj

Tensor.__deepcopy__ = __deepcopy__
StubTensor.__deepcopy__ = __deepcopy__

def asnumpy(self):
return Tensor_.asnumpy(self)

Tensor.asnumpy = asnumpy
StubTensor.asnumpy = _stub_method(asnumpy)

def backward(self, *args, **kwargs):
pass

Tensor.backward = backward
StubTensor.backward = backward

def __repr__(self):
Tensor_.data_sync(self, True)
return Tensor_.__repr__(self)

Tensor.__repr__ = __repr__
StubTensor.__repr__ = _stub_method(__repr__)


def detach_(self):
return ops.stop_gradient(self)

def _rebuild_from_type_v2(func, new_type, args, state):
ret = func(*args)
return ret
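
Taken together, these patches give the MindSpore Tensor several torch.Tensor idioms. A minimal sketch, assuming enable_mindspore_patch() has run:

import copy
from mindspore import Tensor

x = Tensor(2, 3)        # all-int positional args now mean shape, filled via the default dtype
y = copy.deepcopy(x)    # __deepcopy__ returns a fresh Tensor copy
print(x)                # __repr__ syncs device data before formatting
x.backward()            # deliberate no-op stub; MindSpore drives autograd itself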
2 changes: 1 addition & 1 deletion mindnlp/core/backends/__init__.py
@@ -1 +1 @@
from . import cuda, mps
from . import cuda, mps, cudnn
15 changes: 15 additions & 0 deletions mindnlp/core/backends/cudnn/__init__.py
@@ -0,0 +1,15 @@
from contextlib import contextmanager

@contextmanager
def flags(
enabled=False,
benchmark=False,
benchmark_limit=10,
deterministic=False,
allow_tf32=True,
fp32_precision="none",
):
try:
yield
finally:
pass
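
This is a compatibility shim: the context manager accepts the usual torch.backends.cudnn knobs and ignores them, so torch-targeted code runs unmodified on MindSpore. A usage sketch:

from mindnlp.core.backends import cudnn

with cudnn.flags(enabled=True, benchmark=True, deterministic=False):
    ...  # body runs unchanged; the flags have no effect on this backend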
196 changes: 191 additions & 5 deletions mindnlp/core/nn/functional.py
@@ -6,6 +6,16 @@
import mindspore
from mindspore import ops, mint
from mindspore.ops._primitive_cache import _get_cache_prim
from mindspore.ops.auto_generate import (reflection_pad_1d_op, reflection_pad_2d_op, add_layernorm_v2_op,
reflection_pad_3d_op, # pylint: disable=W0611
replication_pad_1d_op, replication_pad_2d_op, replication_pad_3d_op,
constant_pad_nd_op, dropout_ext_op, reverse_v2_impl, avg_pool2d_op,
upsample_nearest1d_op, upsample_nearest2d_op, upsample_nearest3d_op,
upsample_linear1d_op, upsample_bilinear2d_op, upsample_bicubic2d_op,
upsample_trilinear3d_impl, fill_scalar_op, floor_op, nllloss_2d_op,
masked_fill_op, masked_select, ones, flatten_ext, conv_transpose2d)



from mindnlp import core
from ..configs import DEVICE_TARGET, ON_ORANGE_PI, use_pyboost, ON_A1
@@ -243,7 +253,11 @@ def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.0, sca
return mint.nn.functional.embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq)
return ops.gather(weight, input, 0)

def rms_norm(input, normalized_shape, weight, eps=1e-5):
def rms_norm(input, normalized_shape, weight, eps=None):
if eps is None:
eps = core.finfo(input.dtype).eps
if weight is None:
weight = core.ones(normalized_shape)
return ops.rms_norm(input, weight, eps)[0]
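
With these defaults, eps falls back to the machine epsilon of the input dtype and a missing weight becomes all-ones, so a bare torch-style call works. A hedged sketch (core.randn assumed to mirror torch.randn):

from mindnlp import core
import mindnlp.core.nn.functional as F

x = core.randn(2, 8)
out = F.rms_norm(x, (8,), None)   # eps -> core.finfo(x.dtype).eps, weight -> ones((8,))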

def fast_gelu(x):
@@ -463,7 +477,161 @@ def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5):
return _layer_norm(input, weight, bias)[0]

def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False):
return ops.interpolate(input, size, scale_factor, mode, align_corners, recompute_scale_factor)
if mode in ("nearest", "area", "nearest-exact"):
if align_corners is not None:
raise ValueError(
"align_corners option can only be set with the "
"interpolating modes: linear | bilinear | bicubic | trilinear"
)
else:
if align_corners is None:
align_corners = False

dim = input.dim() - 2 # Number of spatial dimensions.

# Process size and scale_factor. Validate that exactly one is set.
# Validate its length if it is a list, or expand it if it is a scalar.
# After this block, exactly one of output_size and scale_factors will
# be non-None, and it will be a list (or tuple).
if size is not None and scale_factor is not None:
raise ValueError("only one of size or scale_factor should be defined")
elif size is not None:
assert scale_factor is None
scale_factors = None
if isinstance(size, (list, tuple)):
if len(size) != dim:
raise ValueError(
"Input and output must have the same number of spatial dimensions, but got "
f"input with spatial dimensions of {list(input.shape[2:])} and output size of {size}. "
"Please provide input tensor in (N, C, d1, d2, ...,dK) format and "
"output size in (o1, o2, ...,oK) format."
)
output_size = size
else:
output_size = [size for _ in range(dim)]
elif scale_factor is not None:
assert size is None
output_size = None
if isinstance(scale_factor, (list, tuple)):
if len(scale_factor) != dim:
raise ValueError(
"Input and scale_factor must have the same number of spatial dimensions, but "
f"got input with spatial dimensions of {list(input.shape[2:])} and "
f"scale_factor of shape {scale_factor}. "
"Please provide input tensor in (N, C, d1, d2, ...,dK) format and "
"scale_factor in (s1, s2, ...,sK) format."
)
scale_factors = scale_factor
else:
scale_factors = [scale_factor for _ in range(dim)]
else:
raise ValueError("either size or scale_factor should be defined")

if (
recompute_scale_factor is not None
and recompute_scale_factor
and size is not None
):
raise ValueError(
"recompute_scale_factor is not meaningful with an explicit size."
)

# "area" mode always requires an explicit size rather than scale factor.
# Re-use the recompute_scale_factor code path.
if mode in ["area", "bilinear"] and output_size is None:
recompute_scale_factor = True

if recompute_scale_factor is not None and recompute_scale_factor:
# Compute an integer output_size from the scale factors, then un-set
# scale_factors so the backend op recomputes the scales from that size.
assert scale_factors is not None
output_size = [
math.floor(float(input.size(i + 2) * scale_factors[i]))
for i in range(dim)
]
scale_factors = None

if antialias and not (mode in ("bilinear", "bicubic") and input.ndim == 4):
raise ValueError(
"Anti-alias option is restricted to bilinear and bicubic modes and requires a 4-D tensor as input"
)

if input.dim() == 3 and mode == "nearest":
return upsample_nearest1d_op(input, output_size, scale_factors)
if input.dim() == 4 and mode == "nearest":
return upsample_nearest2d_op(input, output_size, scale_factors)
if input.dim() == 5 and mode == "nearest":
return upsample_nearest3d_op(input, output_size, scale_factors)

if input.dim() == 3 and mode == "nearest-exact":
return torch._C._nn._upsample_nearest_exact1d(input, output_size, scale_factors)
if input.dim() == 4 and mode == "nearest-exact":
return torch._C._nn._upsample_nearest_exact2d(input, output_size, scale_factors)
if input.dim() == 5 and mode == "nearest-exact":
return torch._C._nn._upsample_nearest_exact3d(input, output_size, scale_factors)

if input.dim() == 3 and mode == "area":
assert output_size is not None
return adaptive_avg_pool1d(input, output_size)
if input.dim() == 4 and mode == "area":
assert output_size is not None
return adaptive_avg_pool2d(input, output_size)
if input.dim() == 5 and mode == "area":
assert output_size is not None
return adaptive_avg_pool3d(input, output_size)

if input.dim() == 3 and mode == "linear":
assert align_corners is not None
return upsample_linear1d_op(
input, output_size, scale_factors, align_corners
)
if input.dim() == 4 and mode == "bilinear":
assert align_corners is not None
if antialias:
# torch._C._nn is unavailable here; no antialiased MindSpore kernel is mapped.
raise NotImplementedError("antialiased bilinear upsampling is not supported in this port")
return upsample_bilinear2d_op(
input, output_size, scale_factors, align_corners
)
if input.dim() == 5 and mode == "trilinear":
assert align_corners is not None
return upsample_trilinear3d_impl(
input, output_size, scale_factors, align_corners
)
if input.dim() == 4 and mode == "bicubic":
assert align_corners is not None
if antialias:
# torch._C._nn is unavailable here; no antialiased MindSpore kernel is mapped.
raise NotImplementedError("antialiased bicubic upsampling is not supported in this port")
return upsample_bicubic2d_op(
input, output_size, scale_factors, align_corners
)

if input.dim() == 3 and mode == "bilinear":
raise NotImplementedError("Got 3D input, but bilinear mode needs 4D input")
if input.dim() == 3 and mode == "trilinear":
raise NotImplementedError("Got 3D input, but trilinear mode needs 5D input")
if input.dim() == 4 and mode == "linear":
raise NotImplementedError("Got 4D input, but linear mode needs 3D input")
if input.dim() == 4 and mode == "trilinear":
raise NotImplementedError("Got 4D input, but trilinear mode needs 5D input")
if input.dim() == 5 and mode == "linear":
raise NotImplementedError("Got 5D input, but linear mode needs 3D input")
if input.dim() == 5 and mode == "bilinear":
raise NotImplementedError("Got 5D input, but bilinear mode needs 4D input")

raise NotImplementedError(
"Input Error: Only 3D, 4D and 5D input Tensors supported"
f" (got {input.dim()}D) for the modes: nearest | linear | bilinear | bicubic | trilinear | area | nearest-exact"
f" (got {mode})"
)
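
The rewrite dispatches on rank and mode exactly like torch.nn.functional.interpolate instead of deferring to ops.interpolate. A usage sketch under that assumption (core.randn illustrative):

x = core.randn(1, 3, 16, 16)                           # (N, C, H, W)
up = interpolate(x, scale_factor=2.0, mode="nearest")  # -> (1, 3, 32, 32)
bi = interpolate(x, size=(24, 24), mode="bilinear", align_corners=False)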

def normalize(input, p=2.0, dim=1, eps=1e-6):
r"""
@@ -599,8 +767,24 @@ def conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
raise ValueError("Requires mindspore >= 2.3.0 by default, or set into pyboost mode by calling torch.config.set_byboost(True).")

def conv_transpose1d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
return mint.nn.functional.conv_transpose1d(input, weight, bias, stride, padding, output_padding, groups, dilation)

# 1. Add a dummy height dimension to the input.
x_2d = input.unsqueeze(2) # (batch, in_channels, 1, L_in)

# 2. Add the same height dimension to the kernel.
weight_2d = weight.unsqueeze(2) # (in_channels, out_channels, 1, kernel_size)

# Normalize scalar hyper-parameters to 1-tuples so the dummy axis can be prepended.
stride = (stride,) if isinstance(stride, int) else tuple(stride)
padding = (padding,) if isinstance(padding, int) else tuple(padding)
output_padding = (output_padding,) if isinstance(output_padding, int) else tuple(output_padding)
dilation = (dilation,) if isinstance(dilation, int) else tuple(dilation)

# 3. Run a 2-D transposed convolution.
output_2d = conv_transpose2d(
x_2d,
weight_2d,
bias,
stride=(1,) + stride,
padding=(0,) + padding,
output_padding=(0,) + output_padding,
groups=groups,
dilation=(1,) + dilation
) # output shape: (batch, out_channels, 1, L_out)

# 4. Drop the height dimension to restore the 1-D layout.
return output_2d.squeeze(2)
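
Because the 1-D case rides on the 2-D kernel with a dummy height axis, output length follows the usual transposed-conv arithmetic: L_out = (L_in - 1)*stride - 2*padding + dilation*(kernel_size - 1) + output_padding + 1. A hedged sketch (core.randn illustrative):

x = core.randn(4, 8, 32)              # (batch, in_channels, L_in)
w = core.randn(8, 16, 5)              # (in_channels, out_channels, kernel_size)
y = conv_transpose1d(x, w, stride=2)  # -> (4, 16, 67) by the formula above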

def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
return mint.nn.functional.conv_transpose2d(input, weight, bias, stride, padding, output_padding, groups, dilation)
@@ -1221,7 +1405,9 @@ def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1):
return ops.fold(input, output_size, kernel_size, dilation, padding, stride)

def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean', zero_infinity=False):
ctc_loss_op = _get_cache_prim(nn_ops.CTCLossV2)(blank=blank, reduction="none", zero_infinity=zero_infinity)
ctc_loss_op = _get_cache_prim(ops.CTCLossV2)(blank=blank, reduction="none", zero_infinity=zero_infinity)
if targets.ndim == 1:
targets = targets.unsqueeze(-1)
loss, _ = ctc_loss_op(log_probs, targets, input_lengths, target_lengths)
if zero_infinity:
loss = ops.where(ops.isinf(loss), 0., loss)
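
The new guard lifts 1-D targets to the 2-D layout CTCLossV2 expects. A shape sketch, assuming torch-style factory functions and tensor methods on core (illustrative only):

# Batch of 4, 50 time steps, 20 classes, each target 10 labels long.
log_probs = core.randn(50, 4, 20).log_softmax(2)     # (T, N, C)
targets = core.randint(1, 20, (4, 10))               # (N, S); a 1-D vector is now unsqueezed to 2-D
input_lengths = core.full((4,), 50, dtype=core.long)
target_lengths = core.full((4,), 10, dtype=core.long)
loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)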
2 changes: 1 addition & 1 deletion mindnlp/core/nn/modules/__init__.py
@@ -3,7 +3,7 @@
from .container import ModuleList, ParameterList, Sequential, ParameterDict, ModuleDict
from .linear import Linear, Identity
from .sparse import Embedding
from .normalization import LayerNorm, GroupNorm
from .normalization import LayerNorm, GroupNorm, RMSNorm
from .dropout import Dropout, Dropout2d
from .activation import *
from .conv import Conv3d, Conv2d, Conv1d, ConvTranspose2d, ConvTranspose1d
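
With RMSNorm now re-exported alongside LayerNorm and GroupNorm, torch-style module code imports cleanly. A sketch assuming the torch-like constructor signature:

from mindnlp import core
from mindnlp.core import nn

norm = nn.RMSNorm(768)            # assumed signature: RMSNorm(normalized_shape, eps=None)
y = norm(core.randn(2, 16, 768))  # normalizes over the last dimension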
10 changes: 8 additions & 2 deletions mindnlp/core/nn/modules/batchnorm.py
@@ -45,8 +45,14 @@ def __init__(
self.register_buffer('running_var', ops.ones(num_features,))
self.running_mean: Optional[Tensor]
self.running_var: Optional[Tensor]
self.register_buffer('num_batches_tracked',
Tensor(0, dtype=core.int64))
self.register_buffer(
"num_batches_tracked",
core.tensor(
0,
dtype=core.long,
**{k: v for k, v in factory_kwargs.items() if k != "dtype"},
),
)
self.num_batches_tracked: Optional[Tensor]
else:
self.register_buffer("running_mean", None)
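
The replacement buffer keeps the counter an int64 scalar while forwarding any remaining factory_kwargs (e.g. device); dtype is deliberately filtered out so a float module dtype never leaks into the counter. A quick illustrative check, assuming a torch-like nn.BatchNorm2d:

from mindnlp import core
from mindnlp.core import nn

bn = nn.BatchNorm2d(8)
assert bn.num_batches_tracked.dtype == core.long  # stays int64 regardless of module dtype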