diff --git a/mindnlp/core/_C/__init__.py b/mindnlp/core/_C/__init__.py
index 5ffd108e1..8e1034a08 100644
--- a/mindnlp/core/_C/__init__.py
+++ b/mindnlp/core/_C/__init__.py
@@ -1,4 +1,4 @@
-from mindspore import default_generator, Generator as msGenerator
+from mindspore import Generator as msGenerator
 from . import _nn
 from ..types import device as device_
@@ -36,6 +36,10 @@ def __init__(self, device='cpu'):
     @property
     def device(self):
-        return self._device
+        if hasattr(self, '_device'):
+            return self._device
+        return device_('cpu')
+
+default_generator = Generator()
 
 class Tag: pass
diff --git a/mindnlp/core/_dtype.py b/mindnlp/core/_dtype.py
index ba5e312ae..9a979a871 100644
--- a/mindnlp/core/_dtype.py
+++ b/mindnlp/core/_dtype.py
@@ -1,8 +1,15 @@
+import warnings
 import numpy as np
 from mindspore.common.dtype import *
 from mindspore._c_expression import typing
 from mindspore._c_expression.typing import Type
+from .configs import ON_A1
+
+if ON_A1:
+    warnings.warn('910A does not support bfloat16, use float16 instead.')
+    bfloat16 = float16
+
 dtype = Type
 
 def is_floating_point(self):
diff --git a/mindnlp/core/_tensor.py b/mindnlp/core/_tensor.py
index 68b6639ff..646af3318 100644
--- a/mindnlp/core/_tensor.py
+++ b/mindnlp/core/_tensor.py
@@ -137,9 +137,6 @@ def to_(self, *args, **kwargs):
     else:
         dtype_to = kwargs.get("dtype", None)
     if dtype_to is not None:
-        if ON_A1 and dtype_to == _dtype.bfloat16:
-            warnings.warn('910A do not support bfloat16, use float16 instead.')
-            return mindspore.ops.cast(self, _dtype.float16)
         return mindspore.ops.cast(self, dtype_to)
     return self
@@ -592,9 +589,6 @@ def real(self):
 StubTensor.real = real
 
 def bfloat16(self):
-    if ON_A1:
-        warnings.warn('910A do not support bfloat16, use float16 instead.')
-        return self.to(_dtype.float16)
     return self.to(_dtype.bfloat16)
 Tensor.bfloat16 = bfloat16
@@ -639,6 +633,9 @@ def __contains__(self, item):
 Tensor.unflatten = ops.unflatten
 StubTensor.unflatten = ops.unflatten
 
+Tensor.round_ = ops.inplace_round
+StubTensor.round_ = ops.inplace_round
+
 def _rebuild_from_type_v2(func, new_type, args, state):
     ret = func(*args)
     return ret
\ No newline at end of file
diff --git a/mindnlp/core/amp/autocast_mode.py b/mindnlp/core/amp/autocast_mode.py
index ff2888c48..5716167f4 100644
--- a/mindnlp/core/amp/autocast_mode.py
+++ b/mindnlp/core/amp/autocast_mode.py
@@ -82,11 +82,38 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): # type: ignore[ov
     def __call__(self, func):
         return autocast_decorator(self, func)
 
+def _cast(value, device_type: str, dtype):
+    if isinstance(value, core.Tensor):
+        is_eligible = (
+            value.is_floating_point()
+            and value.device.type == device_type
+            and (value.dtype is not core.float64)
+        )
+        return value.to(dtype) if is_eligible else value
+    elif isinstance(value, (str, bytes)):
+        return value
+    elif HAS_NUMPY and isinstance(value, np.ndarray):
+        return value
+    elif isinstance(value, collections.abc.Mapping):
+        return {
+            _cast(k, device_type, dtype): _cast(v, device_type, dtype)
+            for k, v in value.items()
+        }
+    elif isinstance(value, collections.abc.Iterable):
+        iterable = (_cast(v, device_type, dtype) for v in value)
+        if isinstance(value, (list, tuple)):
+            return type(value)(iterable)
+        else:
+            return iterable
+    else:
+        return value
+
+
 def custom_fwd(
     fwd=None,
     *,
     device_type: str,
-    cast_inputs: Optional[_dtype] = None,
+    cast_inputs = None,
 ):
     """
     Create a helper decorator for ``forward`` methods of custom autograd functions.
diff --git a/mindnlp/core/distributions/__init__.py b/mindnlp/core/distributions/__init__.py
index 3ca74fcd0..1d320163f 100644
--- a/mindnlp/core/distributions/__init__.py
+++ b/mindnlp/core/distributions/__init__.py
@@ -10,4 +10,5 @@
 from .transforms import *
 from .relaxed_categorical import *
 from .relaxed_bernoulli import *
-from .multivariate_normal import *
\ No newline at end of file
+from .multivariate_normal import *
+from .gumbel import *
\ No newline at end of file
diff --git a/mindnlp/core/distributions/gumbel.py b/mindnlp/core/distributions/gumbel.py
new file mode 100644
index 000000000..8a9416c12
--- /dev/null
+++ b/mindnlp/core/distributions/gumbel.py
@@ -0,0 +1,85 @@
+# mypy: allow-untyped-defs
+import math
+
+from mindnlp import core
+from mindnlp.core import Tensor
+from . import constraints
+from .transformed_distribution import TransformedDistribution
+from .transforms import AffineTransform, ExpTransform
+from .uniform import Uniform
+from .utils import broadcast_all, euler_constant
+from mindnlp.core.types import _Number
+
+
+__all__ = ["Gumbel"]
+
+
+class Gumbel(TransformedDistribution):
+    r"""
+    Samples from a Gumbel Distribution.
+
+    Examples::
+
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> m = Gumbel(torch.tensor([1.0]), torch.tensor([2.0]))
+        >>> m.sample()  # sample from Gumbel distribution with loc=1, scale=2
+        tensor([ 1.0124])
+
+    Args:
+        loc (float or Tensor): Location parameter of the distribution
+        scale (float or Tensor): Scale parameter of the distribution
+    """
+
+    arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
+    support = constraints.real
+
+    def __init__(self, loc, scale, validate_args=None):
+        self.loc, self.scale = broadcast_all(loc, scale)
+        finfo = core.finfo(self.loc.dtype)
+        if isinstance(loc, _Number) and isinstance(scale, _Number):
+            base_dist = Uniform(finfo.tiny, 1 - finfo.eps, validate_args=validate_args)
+        else:
+            base_dist = Uniform(
+                core.full_like(self.loc, finfo.tiny),
+                core.full_like(self.loc, 1 - finfo.eps),
+                validate_args=validate_args,
+            )
+        transforms = [
+            ExpTransform().inv,
+            AffineTransform(loc=0, scale=-core.ones_like(self.scale)),
+            ExpTransform().inv,
+            AffineTransform(loc=loc, scale=-self.scale),
+        ]
+        super().__init__(base_dist, transforms, validate_args=validate_args)
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(Gumbel, _instance)
+        new.loc = self.loc.expand(batch_shape)
+        new.scale = self.scale.expand(batch_shape)
+        return super().expand(batch_shape, _instance=new)
+
+    # Explicitly defining the log probability function for Gumbel due to precision issues
+    def log_prob(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        y = (self.loc - value) / self.scale
+        return (y - y.exp()) - self.scale.log()
+
+    @property
+    def mean(self) -> Tensor:
+        return self.loc + self.scale * euler_constant
+
+    @property
+    def mode(self) -> Tensor:
+        return self.loc
+
+    @property
+    def stddev(self) -> Tensor:
+        return (math.pi / math.sqrt(6)) * self.scale
+
+    @property
+    def variance(self) -> Tensor:
+        return self.stddev.pow(2)
+
+    def entropy(self):
+        return self.scale.log() + (1 + euler_constant)
\ No newline at end of file
diff --git a/mindnlp/core/distributions/uniform.py b/mindnlp/core/distributions/uniform.py
new file mode 100644
index 000000000..b87936d8b
--- /dev/null
+++ b/mindnlp/core/distributions/uniform.py
@@ -0,0 +1,101 @@
+# mypy: allow-untyped-defs
+from mindnlp import core
+from mindnlp.core import nan, Tensor
+from . import constraints
+from .distribution import Distribution
+from .utils import broadcast_all
+from mindnlp.core.types import _Number, _size
+
+
+__all__ = ["Uniform"]
+
+
+class Uniform(Distribution):
+    r"""
+    Generates uniformly distributed random samples from the half-open interval
+    ``[low, high)``.
+
+    Example::
+
+        >>> m = Uniform(torch.tensor([0.0]), torch.tensor([5.0]))
+        >>> m.sample()  # uniformly distributed in the range [0.0, 5.0)
+        >>> # xdoctest: +SKIP
+        tensor([ 2.3418])
+
+    Args:
+        low (float or Tensor): lower range (inclusive).
+        high (float or Tensor): upper range (exclusive).
+    """
+
+    # TODO allow (loc,scale) parameterization to allow independent constraints.
+    arg_constraints = {
+        "low": constraints.dependent(is_discrete=False, event_dim=0),
+        "high": constraints.dependent(is_discrete=False, event_dim=0),
+    }
+    has_rsample = True
+
+    @property
+    def mean(self) -> Tensor:
+        return (self.high + self.low) / 2
+
+    @property
+    def mode(self) -> Tensor:
+        return nan * self.high
+
+    @property
+    def stddev(self) -> Tensor:
+        return (self.high - self.low) / 12**0.5
+
+    @property
+    def variance(self) -> Tensor:
+        return (self.high - self.low).pow(2) / 12
+
+    def __init__(self, low, high, validate_args=None):
+        self.low, self.high = broadcast_all(low, high)
+
+        if isinstance(low, _Number) and isinstance(high, _Number):
+            batch_shape = core.Size()
+        else:
+            batch_shape = self.low.size()
+        super().__init__(batch_shape, validate_args=validate_args)
+
+        if self._validate_args and not core.lt(self.low, self.high).all():
+            raise ValueError("Uniform is not defined when low>= high")
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(Uniform, _instance)
+        batch_shape = core.Size(batch_shape)
+        new.low = self.low.expand(batch_shape)
+        new.high = self.high.expand(batch_shape)
+        super(Uniform, new).__init__(batch_shape, validate_args=False)
+        new._validate_args = self._validate_args
+        return new
+
+    @constraints.dependent_property(is_discrete=False, event_dim=0)
+    def support(self):
+        return constraints.interval(self.low, self.high)
+
+    def rsample(self, sample_shape: _size = core.Size()) -> Tensor:
+        shape = self._extended_shape(sample_shape)
+        rand = core.rand(shape, dtype=self.low.dtype, device=self.low.device)
+        return self.low + rand * (self.high - self.low)
+
+    def log_prob(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        lb = self.low.le(value).type_as(self.low)
+        ub = self.high.gt(value).type_as(self.low)
+        return core.log(lb.mul(ub)) - core.log(self.high - self.low)
+
+    def cdf(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        result = (value - self.low) / (self.high - self.low)
+        return result.clamp(min=0, max=1)
+
+    def icdf(self, value):
+        result = value * (self.high - self.low) + self.low
+        return result
+
+    def entropy(self):
+        return core.log(self.high - self.low)
\ No newline at end of file
diff --git a/mindnlp/core/nn/functional.py b/mindnlp/core/nn/functional.py
index 9dfc8c8c9..3815910eb 100644
--- a/mindnlp/core/nn/functional.py
+++ b/mindnlp/core/nn/functional.py
@@ -184,7 +184,7 @@ def adaptive_avg_pool2d(input, output_size):
     return ops.adaptive_avg_pool2d(input, output_size)
 
 def dropout(input, p=0.5, training=True):
-    if not training or p == 0:
+    if not training or p == 0 or 0 in input.shape:
         return input
     if use_pyboost() and not ON_ORANGE_PI:
         return mint.nn.functional.dropout(input, p, training)
@@ -548,6 +548,7 @@ def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corne
             scale_factors = scale_factor
         else:
             scale_factors = [scale_factor for _ in range(dim)]
+        scale_factors = [float(scale_factor) for scale_factor in scale_factors]
     else:
         raise ValueError("either size or scale_factor should be defined")
@@ -562,7 +563,7 @@ def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corne
 
     # "area" mode always requires an explicit size rather than scale factor.
     # Re-use the recompute_scale_factor code path.
-    if mode in ["area", "bilinear", "bicubic"] and output_size is None:
+    if mode in ["area", "bilinear", "bicubic", "nearest-exact"] and output_size is None:
         recompute_scale_factor = True
 
     if recompute_scale_factor is not None and recompute_scale_factor:
@@ -595,7 +596,10 @@ def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corne
     if input.dim() == 3 and mode == "nearest-exact":
         return torch._C._nn._upsample_nearest_exact1d(input, output_size, scale_factors)
     if input.dim() == 4 and mode == "nearest-exact":
-        return torch._C._nn._upsample_nearest_exact2d(input, output_size, scale_factors)
+        nearest_exact = _get_cache_prim(ops.ResizeNearestNeighborV2)(
+            align_corners=False,
+            half_pixel_centers=True)
+        return nearest_exact(input, output_size)
     if input.dim() == 5 and mode == "nearest-exact":
         return torch._C._nn._upsample_nearest_exact3d(input, output_size, scale_factors)
@@ -814,7 +818,34 @@ def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_paddi
     return mint.nn.functional.conv_transpose2d(input, weight, bias, stride, padding, output_padding, groups, dilation)
 
 def conv_transpose3d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
-    return mint.nn.functional.conv_transpose3d(input, weight, bias, stride, padding, output_padding, groups, dilation)
+    in_channel, out_channel = weight.shape[0], weight.shape[1]
+    kernel_size = weight.shape[2:]
+    conv_transpose3d_op = ops.Conv3DTranspose(
+        in_channel,
+        out_channel,
+        kernel_size,
+        mode=1,
+        pad_mode='valid',
+        pad=padding,
+        stride=stride,
+        dilation=dilation,
+        group=1,
+        output_padding=output_padding,
+        data_format="NCDHW"
+    )
+    if groups > 1:
+        outputs = ()
+        for i in range(groups):
+            output = conv_transpose3d_op(input.half(), weight.half())
+            if bias is not None:
+                output = output + bias
+            outputs = outputs + (output,)
+        out = ops.concat(outputs, 1)
+    else:
+        out = conv_transpose3d_op(input, weight)
+        if bias is not None:
+            out = out + bias
+    return out
 
 def max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False):
diff --git a/mindnlp/core/nn/modules/normalization.py b/mindnlp/core/nn/modules/normalization.py
index ed3928f9d..400ee250f 100644
--- a/mindnlp/core/nn/modules/normalization.py
+++ b/mindnlp/core/nn/modules/normalization.py
@@ -64,8 +64,8 @@ class LayerNorm(Module):
     .. _`Layer Normalization`: https://arxiv.org/abs/1607.06450
     """
 
-    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True, bias: bool = True,dtype=None):
-        factory_kwargs = {'dtype': dtype}
+    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True, bias: bool = True, dtype=None, device=None):
+        factory_kwargs = {'dtype': dtype, 'device': device}
         super(LayerNorm, self).__init__()
         if isinstance(normalized_shape, numbers.Integral):
             normalized_shape = (normalized_shape,)
diff --git a/mindnlp/core/ops/array.py b/mindnlp/core/ops/array.py
index 65d312222..b5ebb219e 100644
--- a/mindnlp/core/ops/array.py
+++ b/mindnlp/core/ops/array.py
@@ -403,9 +403,12 @@ def where(condition, *args, out=None):
         return nonzero(condition, as_tuple=True)
     assert len(args) == 2
     input, other = args
+
     if isinstance(input, float) and input == -float("inf"):
         input = finfo(other.dtype).min
     if isinstance(other, float) and other == -float("inf"):
+        if isinstance(input, numbers.Number):
+            input = mindspore.tensor(input, dtype=mindspore.float32)
         other = finfo(input.dtype).min
 
     output = mindspore.mint.where(condition, input, other)
diff --git a/mindnlp/core/ops/inplace.py b/mindnlp/core/ops/inplace.py
index 396f414a7..111f92465 100644
--- a/mindnlp/core/ops/inplace.py
+++ b/mindnlp/core/ops/inplace.py
@@ -127,6 +127,11 @@ def inplace_triu(input, diagonal=0):
     input.assign_value(out)
     return input
 
+def inplace_round(input, decimals=0):
+    out = ops.round(input, decimals=decimals)
+    input.assign_value(out)
+    return input
+
 
 __all__ = [
@@ -142,5 +147,6 @@ def inplace_triu(input, diagonal=0):
     'inplace_squeeze',
     'inplace_unsqueeze',
     'inplace_fill_diagonal',
-    'inplace_triu'
+    'inplace_triu',
+    'inplace_round'
 ]
diff --git a/mindnlp/core/ops/other.py b/mindnlp/core/ops/other.py
index 7ff0bddd7..d02c2bd18 100644
--- a/mindnlp/core/ops/other.py
+++ b/mindnlp/core/ops/other.py
@@ -737,7 +737,7 @@ def meshgrid(*tensors, indexing=None):
 
 # repeat_interleave
 has_repeat_interleave = hasattr(mindspore.mint, 'repeat_interleave')
-def repeat_interleave(input, repeats, dim=None):
+def repeat_interleave(input, repeats, dim=None, *, output_size=None):
     if use_pyboost() and has_repeat_interleave and not ON_A1:
         return mindspore.mint.repeat_interleave(input, repeats, dim=dim)
diff --git a/mindnlp/core/ops/reduction.py b/mindnlp/core/ops/reduction.py
index 7f91ca59a..71b502f3a 100644
--- a/mindnlp/core/ops/reduction.py
+++ b/mindnlp/core/ops/reduction.py
@@ -74,6 +74,8 @@ def max(*args, **kwargs):
     out = kwargs.pop('out', None)
     if 'dim' in kwargs and 'keepdim' not in kwargs:
         kwargs['keepdim'] = False
+    if 'axis' in kwargs:
+        kwargs['dim'] = kwargs.pop('axis')
     out = mindspore.mint.max(*args, **kwargs)
     if isinstance(out, tuple):
         return max_out(values=out[0], indices=out[1])
diff --git a/mindnlp/core/random.py b/mindnlp/core/random.py
index e9d35838a..bfa2c17bb 100644
--- a/mindnlp/core/random.py
+++ b/mindnlp/core/random.py
@@ -4,7 +4,9 @@
 from typing import Generator
 
 from mindnlp import core
-from mindspore import default_generator, set_seed
+# from mindspore import default_generator, set_seed
+
+default_generator = core._C.default_generator
 
 def get_rng_state():
     """