diff --git a/mindnlp/core/__init__.py b/mindnlp/core/__init__.py index 1b4dfe125..cf140038b 100644 --- a/mindnlp/core/__init__.py +++ b/mindnlp/core/__init__.py @@ -47,7 +47,7 @@ from .amp import autocast, GradScaler from . import profiler, cuda, optim, amp, compiler, jit, version, __future__, overrides, \ - return_types, linalg, fx, backends, testing, nn, fft + return_types, linalg, fx, backends, testing, nn, fft, _jit_internal, utils from ._lowrank import svd_lowrank from .random import get_rng_state, initial_seed, manual_seed, seed, set_rng_state diff --git a/mindnlp/core/_dynamo/utils.py b/mindnlp/core/_dynamo/utils.py new file mode 100644 index 000000000..97594c287 --- /dev/null +++ b/mindnlp/core/_dynamo/utils.py @@ -0,0 +1,3 @@ +def is_compile_supported(device_type): + return False + diff --git a/mindnlp/core/_jit_internal.py b/mindnlp/core/_jit_internal.py index 6f2ed46d4..824a20089 100644 --- a/mindnlp/core/_jit_internal.py +++ b/mindnlp/core/_jit_internal.py @@ -1,3 +1,18 @@ +from typing import ( # noqa: UP035, F401 # (Dict, List, Tuple) imported by torch.jit.annotations + Any, + Callable, + Dict, + Final, + ForwardRef, + get_args, + get_origin, + List, + Optional, + Tuple, + TypeVar, + Union, +) + class FunctionModifiers: """ Used to denote the behavior of a function in TorchScript. See export() and diff --git a/mindnlp/core/_tensor.py b/mindnlp/core/_tensor.py index 5c4412d18..00757723b 100644 --- a/mindnlp/core/_tensor.py +++ b/mindnlp/core/_tensor.py @@ -463,6 +463,11 @@ def clamp_min(self, value): Tensor.unsqueeze_ = ops.inplace_unsqueeze StubTensor.unsqueeze_ = ops.inplace_unsqueeze + def pin_memory(self, *args, **kwargs): + return self + + Tensor.pin_memory = pin_memory + StubTensor.pin_memory = pin_memory def _rebuild_from_type_v2(func, new_type, args, state): ret = func(*args) diff --git a/mindnlp/core/jit/__init__.py b/mindnlp/core/jit/__init__.py index 9b66f51a3..8f978d79c 100644 --- a/mindnlp/core/jit/__init__.py +++ b/mindnlp/core/jit/__init__.py @@ -9,7 +9,7 @@ # _overload, # _overload_method, # export, - # Final, + Final, # Future, # ignore, # is_scripting, diff --git a/mindnlp/core/jit/annotations.py b/mindnlp/core/jit/annotations.py index b67c8a807..fca778d6d 100644 --- a/mindnlp/core/jit/annotations.py +++ b/mindnlp/core/jit/annotations.py @@ -1,4 +1,4 @@ -from core._jit_internal import ( # type: ignore[attr-defined] +from .._jit_internal import ( # type: ignore[attr-defined] # _Await, # _qualified_name, # Any, @@ -15,7 +15,7 @@ # is_optional, # is_tuple, # is_union, - # List, + List, # Optional, # Tuple, # Union, diff --git a/mindnlp/core/linalg/__init__.py b/mindnlp/core/linalg/__init__.py index 25891956c..5b56215c4 100644 --- a/mindnlp/core/linalg/__init__.py +++ b/mindnlp/core/linalg/__init__.py @@ -1,5 +1,5 @@ from collections import namedtuple -from mindspore import ops +from mindspore import ops, mint from mindspore.ops._primitive_cache import _get_cache_prim from mindnlp import core @@ -21,3 +21,5 @@ def cholesky_ex(A, *, upper=False, check_errors=False, out=None): return linalg_cholesky_ex(out, info) +def norm(A, ord=None, dim=None, keepdim=False, *, out=None, dtype=None): + return mint.norm(A, ord, dim, keepdim, dtype=dtype) diff --git a/mindnlp/core/nn/functional.py b/mindnlp/core/nn/functional.py index 657895647..dda7fc233 100644 --- a/mindnlp/core/nn/functional.py +++ b/mindnlp/core/nn/functional.py @@ -162,6 +162,11 @@ def avg_pool2d(input, kernel_size, stride=None, padding=0, ceil_mode=False, coun return ops.avg_pool2d(input, kernel_size, stride, 
padding, ceil_mode, count_include_pad, divisor_override) +def adaptive_avg_pool1d(input, output_size): + if use_pyboost(): + return mint.nn.functional.adaptive_avg_pool1d(input, output_size) + return ops.adaptive_avg_pool1d(input, output_size) + def adaptive_avg_pool2d(input, output_size): if use_pyboost(): return mint.nn.functional.adaptive_avg_pool2d(input, output_size) @@ -1206,7 +1211,7 @@ def _none_or_dtype(input: Optional[core.Tensor]) -> Optional[int]: raise RuntimeError("input to _none_or_dtype() must be None or core.Tensor") def unfold(input, kernel_size, dilation=1, padding=0, stride=1): - if use_pyboost(): + if use_pyboost() and not ON_A1: return mint.nn.functional.unfold(input, kernel_size, dilation, padding, stride) return ops.unfold(input, kernel_size, dilation, padding, stride) diff --git a/mindnlp/core/nn/modules/__init__.py b/mindnlp/core/nn/modules/__init__.py index 4611cbbcf..6a37fb9a4 100644 --- a/mindnlp/core/nn/modules/__init__.py +++ b/mindnlp/core/nn/modules/__init__.py @@ -8,7 +8,7 @@ from .activation import * from .conv import Conv3d, Conv2d, Conv1d, ConvTranspose2d, ConvTranspose1d from .padding import ZeroPad2d, ConstantPad2d, ConstantPad1d, ConstantPad3d -from .batchnorm import BatchNorm2d, BatchNorm1d +from .batchnorm import BatchNorm2d, BatchNorm1d, SyncBatchNorm from .pooling import AdaptiveAvgPool2d, AvgPool1d, MaxPool2d, MaxPool1d, AdaptiveAvgPool1d, AvgPool2d from .flatten import Unflatten, Flatten from .rnn_cell import RNNCell, GRUCell, LSTMCell diff --git a/mindnlp/core/nn/modules/batchnorm.py b/mindnlp/core/nn/modules/batchnorm.py index c2a403d31..17b88d945 100644 --- a/mindnlp/core/nn/modules/batchnorm.py +++ b/mindnlp/core/nn/modules/batchnorm.py @@ -1,5 +1,5 @@ """batch norm""" -from typing import Optional +from typing import Optional, Any from mindnlp.core import Tensor from mindnlp import core from ..parameter import Parameter @@ -371,3 +371,289 @@ class BatchNorm3d(_BatchNorm): def _check_input_dim(self, input): if input.dim() != 5: raise ValueError(f"expected 5D input (got {input.dim()}D input)") + +class SyncBatchNorm(_BatchNorm): + r"""Applies Batch Normalization over a N-Dimensional input. + + The N-D input is a mini-batch of [N-2]D inputs with additional channel dimension) as described in the paper + `Batch Normalization: Accelerating Deep Network Training by Reducing + Internal Covariate Shift `__ . + + .. math:: + + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + + The mean and standard-deviation are calculated per-dimension over all + mini-batches of the same process groups. :math:`\gamma` and :math:`\beta` + are learnable parameter vectors of size `C` (where `C` is the input size). + By default, the elements of :math:`\gamma` are sampled from + :math:`\mathcal{U}(0, 1)` and the elements of :math:`\beta` are set to 0. + The standard-deviation is calculated via the biased estimator, equivalent to + `torch.var(input, unbiased=False)`. + + Also by default, during training this layer keeps running estimates of its + computed mean and variance, which are then used for normalization during + evaluation. The running estimates are kept with a default :attr:`momentum` + of 0.1. + + If :attr:`track_running_stats` is set to ``False``, this layer then does not + keep running estimates, and batch statistics are instead used during + evaluation time as well. + + .. note:: + This :attr:`momentum` argument is different from one used in optimizer + classes and the conventional notion of momentum. 
Mathematically, the + update rule for running statistics here is + :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`, + where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the + new observed value. + + Because the Batch Normalization is done for each channel in the ``C`` dimension, computing + statistics on ``(N, +)`` slices, it's common terminology to call this Volumetric Batch + Normalization or Spatio-temporal Batch Normalization. + + Currently :class:`SyncBatchNorm` only supports + :class:`~torch.nn.DistributedDataParallel` (DDP) with single GPU per process. Use + :meth:`torch.nn.SyncBatchNorm.convert_sync_batchnorm()` to convert + :attr:`BatchNorm*D` layer to :class:`SyncBatchNorm` before wrapping + Network with DDP. + + Args: + num_features: :math:`C` from an expected input of size + :math:`(N, C, +)` + eps: a value added to the denominator for numerical stability. + Default: ``1e-5`` + momentum: the value used for the running_mean and running_var + computation. Can be set to ``None`` for cumulative moving average + (i.e. simple average). Default: 0.1 + affine: a boolean value that when set to ``True``, this module has + learnable affine parameters. Default: ``True`` + track_running_stats: a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics, and initializes statistics + buffers :attr:`running_mean` and :attr:`running_var` as ``None``. + When these buffers are ``None``, this module always uses batch statistics. + in both training and eval modes. Default: ``True`` + process_group: synchronization of stats happen within each process group + individually. Default behavior is synchronization across the whole + world + + Shape: + - Input: :math:`(N, C, +)` + - Output: :math:`(N, C, +)` (same shape as input) + + .. note:: + Synchronization of batchnorm statistics occurs only while training, i.e. + synchronization is disabled when ``model.eval()`` is set or if + ``self.training`` is otherwise ``False``. + + Examples:: + + >>> # xdoctest: +SKIP + >>> # With Learnable Parameters + >>> m = nn.SyncBatchNorm(100) + >>> # creating process group (optional) + >>> # ranks is a list of int identifying rank ids. + >>> ranks = list(range(8)) + >>> r1, r2 = ranks[:4], ranks[4:] + >>> # Note: every rank calls into new_group for every + >>> # process group created, even if that rank is not + >>> # part of the group. 
+ >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]] + >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1] + >>> # Without Learnable Parameters + >>> m = nn.BatchNorm3d(100, affine=False, process_group=process_group) + >>> input = torch.randn(20, 100, 35, 45, 10) + >>> output = m(input) + + >>> # network is nn.BatchNorm layer + >>> sync_bn_network = nn.SyncBatchNorm.convert_sync_batchnorm(network, process_group) + >>> # only single gpu per process is currently supported + >>> ddp_sync_bn_network = torch.nn.parallel.DistributedDataParallel( + >>> sync_bn_network, + >>> device_ids=[args.local_rank], + >>> output_device=args.local_rank) + """ + + def __init__( + self, + num_features: int, + eps: float = 1e-5, + momentum: Optional[float] = 0.1, + affine: bool = True, + track_running_stats: bool = True, + process_group: Optional[Any] = None, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + num_features, eps, momentum, affine, track_running_stats, **factory_kwargs + ) + self.process_group = process_group + + def _check_input_dim(self, input): + if input.dim() < 2: + raise ValueError(f"expected at least 2D input (got {input.dim()}D input)") + + def _check_non_zero_input_channels(self, input): + if input.size(1) == 0: + raise ValueError( + "SyncBatchNorm number of input channels should be non-zero" + ) + + def forward(self, input: Tensor) -> Tensor: + self._check_input_dim(input) + self._check_non_zero_input_channels(input) + + # exponential_average_factor is set to self.momentum + # (when it is available) only so that it gets updated + # in ONNX graph when this node is exported to ONNX. + if self.momentum is None: + exponential_average_factor = 0.0 + else: + exponential_average_factor = self.momentum + + if self.training and self.track_running_stats: + assert self.num_batches_tracked is not None + self.num_batches_tracked.add_(1) + if self.momentum is None: # use cumulative moving average + exponential_average_factor = 1.0 / self.num_batches_tracked.item() + else: # use exponential moving average + exponential_average_factor = self.momentum + + r""" + Decide whether the mini-batch stats should be used for normalization rather than the buffers. + Mini-batch stats are used in training mode, and in eval mode when buffers are None. + """ + if self.training: + bn_training = True + else: + bn_training = (self.running_mean is None) and (self.running_var is None) + + r""" + Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be + passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are + used for normalization (i.e. in eval mode when buffers are not None). + """ + # If buffers are not to be tracked, ensure that they won't be updated + running_mean = ( + self.running_mean if not self.training or self.track_running_stats else None + ) + running_var = ( + self.running_var if not self.training or self.track_running_stats else None + ) + + # Don't sync batchnorm stats in inference mode (model.eval()). 
+ need_sync = ( + bn_training + and self.training + and torch.distributed.is_available() + and torch.distributed.is_initialized() + ) + if need_sync: + # currently only GPU/PrivateUse1 input is supported + if input.device.type not in [ + "cuda", + "xpu", + torch._C._get_privateuse1_backend_name(), + ]: + raise ValueError( + "SyncBatchNorm expected input tensor to be on GPU or XPU or " + f"{torch._C._get_privateuse1_backend_name()}" + ) + + process_group = torch.distributed.group.WORLD + if self.process_group: + process_group = self.process_group + world_size = torch.distributed.get_world_size(process_group) + need_sync = world_size > 1 + + # fallback to framework BN when synchronization is not necessary + if not need_sync: + return F.batch_norm( + input, + running_mean, + running_var, + self.weight, + self.bias, + bn_training, + exponential_average_factor, + self.eps, + ) + else: + assert bn_training + return sync_batch_norm.apply( + input, + self.weight, + self.bias, + running_mean, + running_var, + self.eps, + exponential_average_factor, + process_group, # type: ignore[possibly-undefined] + world_size, # type: ignore[possibly-undefined] + ) + + @classmethod + def convert_sync_batchnorm(cls, module, process_group=None): + r"""Converts all :attr:`BatchNorm*D` layers in the model to :class:`torch.nn.SyncBatchNorm` layers. + + Args: + module (nn.Module): module containing one or more :attr:`BatchNorm*D` layers + process_group (optional): process group to scope synchronization, + default is the whole world + + Returns: + The original :attr:`module` with the converted :class:`torch.nn.SyncBatchNorm` + layers. If the original :attr:`module` is a :attr:`BatchNorm*D` layer, + a new :class:`torch.nn.SyncBatchNorm` layer object will be returned + instead. + + Example:: + + >>> # Network with nn.BatchNorm layer + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> module = torch.nn.Sequential( + >>> torch.nn.Linear(20, 100), + >>> torch.nn.BatchNorm1d(100), + >>> ).cuda() + >>> # creating process group (optional) + >>> # ranks is a list of int identifying rank ids. + >>> ranks = list(range(8)) + >>> r1, r2 = ranks[:4], ranks[4:] + >>> # Note: every rank calls into new_group for every + >>> # process group created, even if that rank is not + >>> # part of the group. 
+ >>> # xdoctest: +SKIP("distributed") + >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]] + >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1] + >>> sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module, process_group) + + """ + module_output = module + if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): + module_output = torch.nn.SyncBatchNorm( + module.num_features, + module.eps, + module.momentum, + module.affine, + module.track_running_stats, + process_group, + ) + if module.affine: + with torch.no_grad(): + module_output.weight = module.weight + module_output.bias = module.bias + module_output.running_mean = module.running_mean + module_output.running_var = module.running_var + module_output.num_batches_tracked = module.num_batches_tracked + module_output.training = module.training + if hasattr(module, "qconfig"): + module_output.qconfig = module.qconfig + for name, child in module.named_children(): + module_output.add_module( + name, cls.convert_sync_batchnorm(child, process_group) + ) + del module + return module_output \ No newline at end of file diff --git a/mindnlp/core/nn/modules/instancenorm.py b/mindnlp/core/nn/modules/instancenorm.py index 390d3c33a..c6b8f8c5c 100644 --- a/mindnlp/core/nn/modules/instancenorm.py +++ b/mindnlp/core/nn/modules/instancenorm.py @@ -2,8 +2,8 @@ import warnings -from mindnlp import core.nn.functional as F from mindnlp.core import Tensor +from .. import functional as F from .batchnorm import _NormBase diff --git a/mindnlp/core/nn/modules/pooling.py b/mindnlp/core/nn/modules/pooling.py index e39f85b57..29bdef712 100644 --- a/mindnlp/core/nn/modules/pooling.py +++ b/mindnlp/core/nn/modules/pooling.py @@ -313,16 +313,8 @@ class AdaptiveAvgPool1d(_AdaptiveAvgPoolNd): output_size: _size_1_t def forward(self, input: Tensor) -> Tensor: - # Add a dimension to make it 2D - input_2d = input.unsqueeze(2) + return F.adaptive_avg_pool1d(input, self.output_size) - # Perform adaptive average pooling - output_2d = ops.adaptive_avg_pool2d(input_2d, (self.output_size, 1)) - - # Remove the added dimension to make it back to 1D - output_1d = output_2d.squeeze(2) - - return output_1d class _AvgPoolNd(Module): __constants__ = ['kernel_size', 'stride', 'padding', 'ceil_mode', 'count_include_pad'] diff --git a/mindnlp/core/nn/utils/rnn.py b/mindnlp/core/nn/utils/rnn.py new file mode 100644 index 000000000..c1dee4373 --- /dev/null +++ b/mindnlp/core/nn/utils/rnn.py @@ -0,0 +1,120 @@ +import torch +from typing import Iterable, List, NamedTuple, Tuple, Union +import numpy as np +from torch import Tensor +from ..._dtype import np2dtype, dtype2np + + +def _pad_sequence(sequences: List[torch.Tensor], batch_first: bool = True, padding_value: float = 0.0, padding_side: str = "right") -> torch.Tensor: + """ + Pads a list of variable-length sequences to the same length using NumPy, mimicking PyTorch's pad_sequence. + + Args: + sequences (List[torch.Tensor]): List of sequences (Tensors) to pad. + batch_first (bool): If True, output shape will be (batch_size, max_len, *dims). + If False, output shape will be (max_len, batch_size, *dims). + padding_value (float): The value used for padding. + padding_side (str): Either 'left' or 'right', specifying where padding is applied. + + Returns: + torch.Tensor: A tensor with padded sequences. 
+ """ + # Ensure valid padding_side input + assert padding_side in ["left", "right"], "padding_side must be 'left' or 'right'" + + # Get the size of the sequences list + sequences_size = len(sequences) + + # Get the max length of the sequences + max_len = max([seq.size(0) for seq in sequences]) + + # Get the trailing dimensions (if any) + trailing_dims = sequences[0].size()[1:] + + # Create the padded tensor with the padding_value + if batch_first: + out_dims = (sequences_size, max_len) + trailing_dims + else: + out_dims = (max_len, sequences_size) + trailing_dims + + # Use the dtype of the first sequence to ensure consistency + dtype = sequences[0].dtype + dtype = dtype2np[dtype] + out = np.full(out_dims, padding_value, dtype=dtype) # Use the same dtype as input + + # Pad the sequences + for i, seq in enumerate(sequences): + length_i = seq.size(0) + start = max_len - length_i if padding_side == "left" else 0 + + # Convert to NumPy array and handle padding + np_out = out[i] if batch_first else out[:, i] + np_seq = seq.numpy() # Convert to NumPy array + + if batch_first: + np_out[start:start + length_i] = np_seq + else: + np_out[start:start + length_i] = np_seq + + # Convert the NumPy result back to a PyTorch tensor + return torch.tensor(out, dtype=np2dtype[dtype]) # Ensure the output tensor has the same dtype as input + +def pad_sequence( + sequences: Union[Tensor, List[Tensor]], + batch_first: bool = False, + padding_value: float = 0.0, +) -> Tensor: + r"""Pad a list of variable length Tensors with ``padding_value`` + + ``pad_sequence`` stacks a list of Tensors along a new dimension, + and pads them to equal length. For example, if the input is a list of + sequences with size ``L x *`` and ``batch_first`` is False, the output is + of size ``T x B x *``. + + `B` is batch size. It is equal to the number of elements in ``sequences``. + `T` is length of the longest sequence. + `L` is length of the sequence. + `*` is any number of trailing dimensions, including none. + + Example: + >>> from torch.nn.utils.rnn import pad_sequence + >>> a = torch.ones(25, 300) + >>> b = torch.ones(22, 300) + >>> c = torch.ones(15, 300) + >>> pad_sequence([a, b, c]).size() + torch.Size([25, 3, 300]) + + Note: + This function returns a Tensor of size ``T x B x *`` or ``B x T x *`` + where `T` is the length of the longest sequence. This function assumes + trailing dimensions and type of all the Tensors in sequences are same. + + Args: + sequences (list[Tensor]): list of variable length sequences. + batch_first (bool, optional): output will be in ``B x T x *`` if True, or in + ``T x B x *`` otherwise. Default: False. + padding_value (float, optional): value for padded elements. Default: 0. + + Returns: + Tensor of size ``T x B x *`` if :attr:`batch_first` is ``False``. 
+ Tensor of size ``B x T x *`` otherwise + """ + + if not (torch.jit.is_tracing() or torch.jit.is_scripting()): + # JIT doesn't support `Iterable` + if not isinstance(sequences, Iterable): + msg = ('pad_sequence: Expected iterable for input sequences, but got arg of type: ' + f'{type(sequences)}') + raise RuntimeError(msg) + + # In JIT context this leads to, + # RuntimeError: cannot statically infer the expected size of a list in this context + sequences = tuple(sequences) + else: + # For JIT, we only support Union[Tensor, Tuple[Tensor]] + if isinstance(sequences, torch.Tensor): + sequences = sequences.unbind(0) + + # assuming trailing dimensions and type of all the Tensors + # in sequences are same and fetching those from sequences[0] + return _pad_sequence(sequences, batch_first, padding_value) diff --git a/mindnlp/core/ops/comparison.py b/mindnlp/core/ops/comparison.py index 3ab9b7558..4a032aa6f 100644 --- a/mindnlp/core/ops/comparison.py +++ b/mindnlp/core/ops/comparison.py @@ -73,7 +73,8 @@ def isfinite(input): # isin def isin(elements, test_elements): elements = elements.ravel().expand_dims(-1) - test_elements = test_elements.ravel() + if isinstance(test_elements, mindspore.Tensor): + test_elements = test_elements.ravel() included = ops.equal(elements, test_elements) # F.reduce_sum only supports float res = ops.sum(included.int(), -1).astype(mindspore.bool_) diff --git a/mindnlp/core/ops/creation.py b/mindnlp/core/ops/creation.py index b87e66e0b..f32346b91 100644 --- a/mindnlp/core/ops/creation.py +++ b/mindnlp/core/ops/creation.py @@ -13,6 +13,7 @@ from ..configs import use_pyboost, ON_ORANGE_PI from .._bind import get_default_dtype, get_default_device from .utils import py2dtype +from .other import finfo def as_strided(self, size, stride, storage_offset=None): if len(size) != len(stride): @@ -69,7 +70,12 @@ def zeros_like(input, *, dtype=None, memory_format=None, **kwargs): # ones _ones = ops.Ones() has_ones = hasattr(mindspore.mint, 'ones') -def ones(*size, dtype=None, device=None): +def ones(*size, dtype=None, device=None, **kwargs): + if len(size) == 0: + size = kwargs.get('size', None) + if size == () or size == []: + size = ((),) + if isinstance(size[0], (tuple, list)): size = size[0] if dtype is None: @@ -247,8 +253,15 @@ def frombuffer(buffer, *, dtype=None, count=-1, offset=0, requires_grad=False): return mindspore.Tensor(output, dtype=dtype) +def scalar_tensor(value, dtype, device=None): + if value == float("-inf"): + value = finfo(dtype).min + if value == float("inf"): + value = finfo(dtype).max + return mindspore.Tensor(value, dtype=dtype) + __all__ = ['arange', 'as_strided', 'complex', 'empty', 'empty_like', 'eye', 'from_numpy', 'full', 'full_like', 'frombuffer', 'heaviside', 'linspace', 'logspace', 'ones', 'ones_like', - 'polar', 'range', 'zeros', 'zeros_like' + 'polar', 'range', 'zeros', 'zeros_like', 'scalar_tensor' ] \ No newline at end of file diff --git a/mindnlp/core/ops/fft_op.py b/mindnlp/core/ops/fft_op.py index 9ea151c05..e69de29bb 100644 --- a/mindnlp/core/ops/fft_op.py +++ b/mindnlp/core/ops/fft_op.py @@ -1,38 +0,0 @@ -"""fft""" -from mindspore import ops -from mindspore.ops._primitive_cache import _get_cache_prim -from ..configs import use_pyboost -from .array import narrow -from ._inner import pad - -def rfft(input, n=None, dim=-1, norm="backward"): - if use_pyboost(): - return ops.rfft(input, n, dim, norm) - if input.shape[dim] < n: - pad_inf = (0, n - input.shape[dim]) - pad_dims = (0, 0) * (input.ndim - (dim + 1)) + pad_inf - input = pad(input, 
pad_dims) - else: - input = narrow(input, dim, 0, n) - _rfft = _get_cache_prim(ops.FFTWithSize)(input.ndim, False, True, norm) - return _rfft(input) - -def irfft(input, n=None, dim=-1, norm="backward"): - if use_pyboost(): - return ops.irfft(input, n, dim, norm) - if input.shape[dim] < n: - pad_inf = (0, n - input.shape[dim]) - pad_dims = (0, 0) * (input.ndim - (dim + 1)) + pad_inf - input = pad(input, pad_dims) - else: - input = narrow(input, dim, 0, n) - _irfft = _get_cache_prim(ops.FFTWithSize)(input.ndim, True, True, norm) - return _irfft(input) - -def fftn(input, s=None, dim=None, norm=None): - return ops.fftn(input, s, dim, norm) - -def fft(input, s=None, dim=-1, norm=None): - return ops.fft(input, s, dim, norm) - -__all__ = ['fft', 'fftn', 'irfft', 'rfft'] \ No newline at end of file diff --git a/mindnlp/core/ops/other.py b/mindnlp/core/ops/other.py index 7c02b9d44..72e7b0549 100644 --- a/mindnlp/core/ops/other.py +++ b/mindnlp/core/ops/other.py @@ -111,6 +111,8 @@ def bucketize(input, boundaries, *, out_int32=False, right=False, out=None): def cdist(x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary"): + if isinstance(p, int): + p = float(p) if use_pyboost() and has_cdist: return mindspore.mint.cdist(x1, x2, p, compute_mode) return ops.cdist(x1, x2, float(p)) @@ -696,8 +698,10 @@ def histc(input, bins, min, max, *, out=None): def meshgrid(*tensors, indexing=None): + if isinstance(tensors[0], (tuple, list)): + tensors = tensors[0] if use_pyboost() and has_meshgrid: - return mindspore.mint.meshgrid(*tensors, indexing) + return mindspore.mint.meshgrid(*tensors, indexing=indexing) if isinstance(tensors[0], (list, tuple)): tensors = tensors[0] if len(tensors) == 1: diff --git a/mindnlp/core/ops/random.py b/mindnlp/core/ops/random.py index b5d594e31..4b09277df 100644 --- a/mindnlp/core/ops/random.py +++ b/mindnlp/core/ops/random.py @@ -89,6 +89,14 @@ def rand_like(input, *, dtype=None): has_randint = hasattr(mindspore.mint, 'randint') def randint(*args, **kwargs): device = kwargs.pop('device', None) + high = kwargs.pop('high', None) + size = kwargs.pop('size', None) + if high is not None: + args += (high,) + + if size is not None: + args += (size,) + if use_pyboost() and has_randint: return mindspore.mint.randint(*args, **kwargs) return ops.randint(*args, **kwargs)
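
Usage sketches for the main additions in this patch follow. They are illustrative only: import paths and dtype names are assumed to match the torch-compatible layout of `mindnlp.core`, and a working MindSpore backend is required. First, the signature-level changes in `ops/creation.py` and `linalg/__init__.py` (keyword-only `size` in `ones`, the infinity-clamping `scalar_tensor`, and the `mint`-backed `norm`); whether `scalar_tensor` and the dtype constants are re-exported at the top level is an assumption based on the torch layout:

from mindnlp import core

x = core.ones(size=(2, 3))                                   # keyword-only size, handled by the new **kwargs path
v = core.scalar_tensor(float("-inf"), dtype=core.float32)    # -inf is clamped to finfo(dtype).min
n = core.linalg.norm(x, ord=None, dim=1, keepdim=False)      # forwards to mindspore.mint.norm
print(x.shape, v, n.shape)                                   # (2, 3), scalar tensor, (2,)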
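
Next, the pooling change: `AdaptiveAvgPool1d.forward` now delegates to the new `F.adaptive_avg_pool1d` instead of the unsqueeze / 2-D pool / squeeze workaround. A quick equivalence sketch with illustrative shapes (not taken from the patch):

from mindnlp import core
from mindnlp.core import nn
from mindnlp.core.nn import functional as F

x = core.ones(4, 8, 32)                  # (N, C, L)
y_fn = F.adaptive_avg_pool1d(x, 16)      # mint path under pyboost, ops.adaptive_avg_pool1d otherwise
y_mod = nn.AdaptiveAvgPool1d(16)(x)      # module forward is now a thin wrapper over the functional
assert y_fn.shape == y_mod.shape == (4, 8, 16)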
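
For the new `nn/utils/rnn.py`, a usage sketch of `pad_sequence` mirroring its docstring example. It assumes the `torch` names imported at the top of that file resolve to a torch-compatible shim at runtime; the `_pad_sequence` helper pads through a NumPy buffer and converts back using the dtype of the first sequence:

from mindnlp import core
from mindnlp.core.nn.utils.rnn import pad_sequence

a = core.ones(25, 300)
b = core.ones(22, 300)
c = core.ones(15, 300)

out = pad_sequence([a, b, c])                        # (25, 3, 300): T x B x *, zero-padded on the right
out_bf = pad_sequence([a, b, c], batch_first=True)   # (3, 25, 300): B x T x *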
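
Finally, a conversion sketch for the new `SyncBatchNorm`. With no process group initialized, `need_sync` stays false and the converted layers fall back to the plain `F.batch_norm` path, so the model can still run single-device. This assumes `SyncBatchNorm` is re-exported at the `nn` level (as the `modules/__init__.py` change suggests) and that the `torch.*` references inside the class body (distributed group handling, the `_BatchNorm` isinstance check in `convert_sync_batchnorm`) resolve against the core shim at runtime:

from mindnlp import core
from mindnlp.core import nn

model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3),
    nn.BatchNorm2d(16),
)
sync_model = nn.SyncBatchNorm.convert_sync_batchnorm(model)  # BatchNorm2d -> SyncBatchNorm, affine weights and running stats copied

x = core.ones(2, 3, 8, 8)
y = sync_model(x)   # no distributed group initialized, so ordinary batch-norm statistics are used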