8 changes: 6 additions & 2 deletions mindnlp/core/_C/__init__.py
@@ -1,4 +1,4 @@
from mindspore import default_generator, Generator as msGenerator
from mindspore import Generator as msGenerator

from . import _nn
from ..types import device as device_
@@ -36,6 +36,10 @@ def __init__(self, device='cpu'):

@property
def device(self):
return self._device
if hasattr(self, '_device'):
return self._device
return device_('cpu')

default_generator = Generator()

class Tag: pass
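
With this change, `Generator.device` no longer assumes `self._device` was set: it falls back to a CPU device, and the module builds its own `default_generator` instead of re-exporting MindSpore's. A minimal usage sketch (assuming `Generator` records an explicit device in `self._device` only when one is supplied):

```python
# Sketch only; mirrors the Generator shown in the hunk above.
from mindnlp.core._C import Generator, default_generator

g = Generator()                  # no explicit device passed
print(g.device)                  # reports a CPU device whether or not _device was recorded
print(default_generator.device)  # module-level default generator created by this file
```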
7 changes: 7 additions & 0 deletions mindnlp/core/_dtype.py
@@ -1,8 +1,15 @@
import warnings
import numpy as np
from mindspore.common.dtype import *
from mindspore._c_expression import typing
from mindspore._c_expression.typing import Type

from .configs import ON_A1

if ON_A1:
warnings.warn('910A do not support bfloat16, use float16 instead.')
bfloat16 = float16

dtype = Type

def is_floating_point(self):
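
On Ascend 910A (`ON_A1`), `bfloat16` is now aliased to `float16` once at import time, replacing the per-call checks removed from `_tensor.py` below. A sketch of the observable effect, assuming `ON_A1` is exposed by `mindnlp.core.configs` as in the import above:

```python
# Sketch: after importing _dtype on a 910A (ON_A1 == True), bfloat16 is rebound to float16,
# so downstream code requesting bfloat16 silently computes in float16.
from mindnlp.core import _dtype
from mindnlp.core.configs import ON_A1

if ON_A1:
    assert _dtype.bfloat16 == _dtype.float16   # aliased on A1 hardware
else:
    assert _dtype.bfloat16 != _dtype.float16   # genuine bfloat16 elsewhere
```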
9 changes: 3 additions & 6 deletions mindnlp/core/_tensor.py
@@ -137,9 +137,6 @@ def to_(self, *args, **kwargs):
else:
dtype_to = kwargs.get("dtype", None)
if dtype_to is not None:
if ON_A1 and dtype_to == _dtype.bfloat16:
warnings.warn('910A do not support bfloat16, use float16 instead.')
return mindspore.ops.cast(self, _dtype.float16)
return mindspore.ops.cast(self, dtype_to)
return self

@@ -592,9 +589,6 @@ def real(self):
StubTensor.real = real

def bfloat16(self):
if ON_A1:
warnings.warn('910A do not support bfloat16, use float16 instead.')
return self.to(_dtype.float16)
return self.to(_dtype.bfloat16)

Tensor.bfloat16 = bfloat16
@@ -639,6 +633,9 @@ def __contains__(self, item):
Tensor.unflatten = ops.unflatten
StubTensor.unflatten = ops.unflatten

Tensor.round_ = ops.inplace_round
StubTensor.round_ = ops.inplace_round

def _rebuild_from_type_v2(func, new_type, args, state):
ret = func(*args)
return ret
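
Besides dropping the per-call bfloat16 workaround (now handled once in `_dtype.py`), this file gains an in-place rounding method bound to `ops.inplace_round`. A minimal sketch, assuming `core.tensor` mirrors `torch.tensor` and `ops.inplace_round` follows the usual in-place contract of mutating and returning its input:

```python
from mindnlp import core

x = core.tensor([0.4, 1.6, 2.2])
y = x.round_()   # rounds x in place via ops.inplace_round
# x and y should reference the same, now-rounded tensor, matching torch's Tensor.round_ semantics.
```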
29 changes: 28 additions & 1 deletion mindnlp/core/amp/autocast_mode.py
@@ -82,11 +82,38 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): # type: ignore[ov
def __call__(self, func):
return autocast_decorator(self, func)

def _cast(value, device_type: str, dtype):
if isinstance(value, core.Tensor):
is_eligible = (
value.is_floating_point()
and value.device.type == device_type
and (value.dtype is not core.float64)
)
return value.to(dtype) if is_eligible else value
elif isinstance(value, (str, bytes)):
return value
elif HAS_NUMPY and isinstance(value, np.ndarray):
return value
elif isinstance(value, collections.abc.Mapping):
return {
_cast(k, device_type, dtype): _cast(v, device_type, dtype)
for k, v in value.items()
}
elif isinstance(value, collections.abc.Iterable):
iterable = (_cast(v, device_type, dtype) for v in value)
if isinstance(value, (list, tuple)):
return type(value)(iterable)
else:
return iterable
else:
return value


def custom_fwd(
fwd=None,
*,
device_type: str,
cast_inputs: Optional[_dtype] = None,
cast_inputs = None,
):
"""
Create a helper decorator for ``forward`` methods of custom autograd functions.
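
The new `_cast` helper recursively walks mappings and iterables and downcasts only floating-point tensors (excluding float64) that already sit on the target device, which is what `custom_fwd(cast_inputs=...)` needs now that the strict `_dtype` annotation is gone. A behavioural sketch; `core.tensor`/`core.randn` are assumed to mirror their torch counterparts, and `"npu"` is an assumed device-type string:

```python
from mindnlp import core
from mindnlp.core.amp.autocast_mode import _cast

batch = {
    "ids": core.tensor([1, 2, 3]),                 # integer tensor: left untouched
    "feats": [core.randn(2, 2), core.randn(2, 2)],  # float32 tensors in a list: candidates for casting
    "name": "sample-0",                             # strings pass through unchanged
}
out = _cast(batch, device_type="npu", dtype=core.float16)
# Tensors in out["feats"] become float16 only if they are floating point, not float64,
# and report device.type == "npu"; everything else is returned as-is.
```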
3 changes: 2 additions & 1 deletion mindnlp/core/distributions/__init__.py
@@ -10,4 +10,5 @@
from .transforms import *
from .relaxed_categorical import *
from .relaxed_bernoulli import *
from .multivariate_normal import *
from .multivariate_normal import *
from .gumbel import *
85 changes: 85 additions & 0 deletions mindnlp/core/distributions/gumbel.py
@@ -0,0 +1,85 @@
# mypy: allow-untyped-defs
import math

from mindnlp import core
from mindnlp.core import Tensor
from . import constraints
from .transformed_distribution import TransformedDistribution
from .transforms import AffineTransform, ExpTransform
from .uniform import Uniform
from .utils import broadcast_all, euler_constant
from mindnlp.core.types import _Number


__all__ = ["Gumbel"]


class Gumbel(TransformedDistribution):
r"""
Samples from a Gumbel Distribution.

Examples::

>>> # xdoctest: +IGNORE_WANT("non-deterministic")
>>> m = Gumbel(torch.tensor([1.0]), torch.tensor([2.0]))
>>> m.sample() # sample from Gumbel distribution with loc=1, scale=2
tensor([ 1.0124])

Args:
loc (float or Tensor): Location parameter of the distribution
scale (float or Tensor): Scale parameter of the distribution
"""

arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
support = constraints.real

def __init__(self, loc, scale, validate_args=None):
self.loc, self.scale = broadcast_all(loc, scale)
finfo = core.finfo(self.loc.dtype)
if isinstance(loc, _Number) and isinstance(scale, _Number):
base_dist = Uniform(finfo.tiny, 1 - finfo.eps, validate_args=validate_args)
else:
base_dist = Uniform(
core.full_like(self.loc, finfo.tiny),
core.full_like(self.loc, 1 - finfo.eps),
validate_args=validate_args,
)
transforms = [
ExpTransform().inv,
AffineTransform(loc=0, scale=-core.ones_like(self.scale)),
ExpTransform().inv,
AffineTransform(loc=loc, scale=-self.scale),
]
super().__init__(base_dist, transforms, validate_args=validate_args)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(Gumbel, _instance)
new.loc = self.loc.expand(batch_shape)
new.scale = self.scale.expand(batch_shape)
return super().expand(batch_shape, _instance=new)

# Explicitly defining the log probability function for Gumbel due to precision issues
def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
y = (self.loc - value) / self.scale
return (y - y.exp()) - self.scale.log()

@property
def mean(self) -> Tensor:
return self.loc + self.scale * euler_constant

@property
def mode(self) -> Tensor:
return self.loc

@property
def stddev(self) -> Tensor:
return (math.pi / math.sqrt(6)) * self.scale

@property
def variance(self) -> Tensor:
return self.stddev.pow(2)

def entropy(self):
return self.scale.log() + (1 + euler_constant)
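
The transform stack realises the inverse-CDF sampler `x = loc - scale * log(-log(u))` with `u ~ Uniform(0, 1)`, while `log_prob` is written out explicitly because the composed form loses precision: with `z = (value - loc) / scale` it evaluates to `-(z + exp(-z)) - log(scale)`, the standard Gumbel log-density. A quick numerical check (sketch, assuming `core.tensor` mirrors `torch.tensor`):

```python
import math
from mindnlp import core
from mindnlp.core.distributions import Gumbel

m = Gumbel(core.tensor([1.0]), core.tensor([2.0]))   # loc=1, scale=2
x = m.sample()

z = (x - 1.0) / 2.0
manual = -(z + (-z).exp()) - math.log(2.0)           # closed-form Gumbel log-density
# m.log_prob(x) should agree with `manual` up to floating-point error.
```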
101 changes: 101 additions & 0 deletions mindnlp/core/distributions/uniform.py
@@ -0,0 +1,101 @@
# mypy: allow-untyped-defs
from mindnlp import core
from mindnlp.core import nan, Tensor
from . import constraints
from .distribution import Distribution
from .utils import broadcast_all
from mindnlp.core.types import _Number, _size


__all__ = ["Uniform"]


class Uniform(Distribution):
r"""
Generates uniformly distributed random samples from the half-open interval
``[low, high)``.

Example::

>>> m = Uniform(torch.tensor([0.0]), torch.tensor([5.0]))
>>> m.sample() # uniformly distributed in the range [0.0, 5.0)
>>> # xdoctest: +SKIP
tensor([ 2.3418])

Args:
low (float or Tensor): lower range (inclusive).
high (float or Tensor): upper range (exclusive).
"""

# TODO allow (loc,scale) parameterization to allow independent constraints.
arg_constraints = {
"low": constraints.dependent(is_discrete=False, event_dim=0),
"high": constraints.dependent(is_discrete=False, event_dim=0),
}
has_rsample = True

@property
def mean(self) -> Tensor:
return (self.high + self.low) / 2

@property
def mode(self) -> Tensor:
return nan * self.high

@property
def stddev(self) -> Tensor:
return (self.high - self.low) / 12**0.5

@property
def variance(self) -> Tensor:
return (self.high - self.low).pow(2) / 12

def __init__(self, low, high, validate_args=None):
self.low, self.high = broadcast_all(low, high)

if isinstance(low, _Number) and isinstance(high, _Number):
batch_shape = core.Size()
else:
batch_shape = self.low.size()
super().__init__(batch_shape, validate_args=validate_args)

if self._validate_args and not core.lt(self.low, self.high).all():
raise ValueError("Uniform is not defined when low>= high")

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(Uniform, _instance)
batch_shape = core.Size(batch_shape)
new.low = self.low.expand(batch_shape)
new.high = self.high.expand(batch_shape)
super(Uniform, new).__init__(batch_shape, validate_args=False)
new._validate_args = self._validate_args
return new

@constraints.dependent_property(is_discrete=False, event_dim=0)
def support(self):
return constraints.interval(self.low, self.high)

def rsample(self, sample_shape: _size = core.Size()) -> Tensor:
shape = self._extended_shape(sample_shape)
rand = core.rand(shape, dtype=self.low.dtype, device=self.low.device)
return self.low + rand * (self.high - self.low)

def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
lb = self.low.le(value).type_as(self.low)
ub = self.high.gt(value).type_as(self.low)
return core.log(lb.mul(ub)) - core.log(self.high - self.low)

def cdf(self, value):
if self._validate_args:
self._validate_sample(value)
result = (value - self.low) / (self.high - self.low)
return result.clamp(min=0, max=1)

def icdf(self, value):
result = value * (self.high - self.low) + self.low
return result

def entropy(self):
return core.log(self.high - self.low)
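
`rsample` draws standard-uniform noise and rescales it as `low + u * (high - low)`, so samples remain differentiable with respect to `low` and `high` (the reparameterization trick). A short usage sketch, assuming `core.tensor` mirrors `torch.tensor`:

```python
from mindnlp import core
from mindnlp.core.distributions import Uniform

m = Uniform(core.tensor([0.0]), core.tensor([5.0]))
s = m.rsample((4,))      # four reparameterized draws from [0, 5)
lp = m.log_prob(s)       # -log(5) for every sample inside the support
h = m.entropy()          # log(5 - 0)
```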
39 changes: 35 additions & 4 deletions mindnlp/core/nn/functional.py
@@ -184,7 +184,7 @@ def adaptive_avg_pool2d(input, output_size):
return ops.adaptive_avg_pool2d(input, output_size)

def dropout(input, p=0.5, training=True):
if not training or p == 0:
if not training or p == 0 or 0 in input.shape:
return input
if use_pyboost() and not ON_ORANGE_PI:
return mint.nn.functional.dropout(input, p, training)
@@ -548,6 +548,7 @@ def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corne
scale_factors = scale_factor
else:
scale_factors = [scale_factor for _ in range(dim)]
scale_factors = [float(scale_factor) for scale_factor in scale_factors]
else:
raise ValueError("either size or scale_factor should be defined")

@@ -562,7 +563,7 @@ def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corne

# "area" mode always requires an explicit size rather than scale factor.
# Re-use the recompute_scale_factor code path.
if mode in ["area", "bilinear", "bicubic"] and output_size is None:
if mode in ["area", "bilinear", "bicubic", "nearest-exact"] and output_size is None:
recompute_scale_factor = True

if recompute_scale_factor is not None and recompute_scale_factor:
@@ -595,7 +596,10 @@ def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corne
if input.dim() == 3 and mode == "nearest-exact":
return torch._C._nn._upsample_nearest_exact1d(input, output_size, scale_factors)
if input.dim() == 4 and mode == "nearest-exact":
return torch._C._nn._upsample_nearest_exact2d(input, output_size, scale_factors)
nearest_exact = _get_cache_prim(ops.ResizeNearestNeighborV2)(
align_corners=False,
half_pixel_centers=True)
return nearest_exact(input, output_size)
if input.dim() == 5 and mode == "nearest-exact":
return torch._C._nn._upsample_nearest_exact3d(input, output_size, scale_factors)

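
For 4-D inputs, `nearest-exact` now maps onto MindSpore's `ResizeNearestNeighborV2` with `half_pixel_centers=True` (intended to match the half-pixel convention of torch's nearest-exact), and the mode is added to the list that forces `recompute_scale_factor` so an explicit output size is always available. A sketch of the call site; shapes are illustrative and `core.randn` is assumed to mirror `torch.randn`:

```python
from mindnlp import core
from mindnlp.core.nn import functional as F

x = core.randn(1, 3, 8, 8)                                    # NCHW input
y = F.interpolate(x, scale_factor=2.0, mode="nearest-exact")
# The scale factor is first turned into an explicit output size (16, 16),
# then resized with half-pixel centers, so y.shape should be (1, 3, 16, 16).
```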
@@ -814,7 +818,34 @@ def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_paddi
return mint.nn.functional.conv_transpose2d(input, weight, bias, stride, padding, output_padding, groups, dilation)

def conv_transpose3d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
return mint.nn.functional.conv_transpose3d(input, weight, bias, stride, padding, output_padding, groups, dilation)
in_channel, out_channel = weight.shape[0], weight.shape[1]
kernel_size = weight.shape[2:]
conv_transpose3d_op = ops.Conv3DTranspose(
in_channel,
out_channel,
kernel_size,
mode=1,
pad_mode='valid',
pad=padding,
stride=stride,
dilation=dilation,
group=1,
output_padding=output_padding,
data_format="NCDHW"
)
if groups > 1:
outputs = ()
for i in range(groups):
output = conv_transpose3d_op(input.half(), weight.half())
if bias is not None:
output = output + bias
outputs = outputs + (output,)
out = ops.concat(outputs, 1)
else:
out = conv_transpose3d_op(input, weight)
if bias is not None:
out = out + bias
return out


def max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False):
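
`conv_transpose3d` is now built directly on `ops.Conv3DTranspose` rather than `mint`; output spatial sizes follow the usual transposed-convolution formula `D_out = (D_in - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1`. A shape-only sketch with illustrative sizes (groups left at 1, so the grouped branch above is not exercised; `core.randn` assumed to mirror `torch.randn`):

```python
from mindnlp import core
from mindnlp.core.nn import functional as F

x = core.randn(1, 4, 5, 5, 5)   # NCDHW input
w = core.randn(4, 8, 3, 3, 3)   # (in_channels, out_channels/groups, kD, kH, kW)
y = F.conv_transpose3d(x, w, stride=2)
# Expected spatial size: (5 - 1) * 2 + 1 * (3 - 1) + 1 = 11, so y.shape should be (1, 8, 11, 11, 11).
```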
4 changes: 2 additions & 2 deletions mindnlp/core/nn/modules/normalization.py
@@ -64,8 +64,8 @@ class LayerNorm(Module):

.. _`Layer Normalization`: https://arxiv.org/abs/1607.06450
"""
def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True, bias: bool = True,dtype=None):
factory_kwargs = {'dtype': dtype}
def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True, bias: bool = True, dtype=None, device=None):
factory_kwargs = {'dtype': dtype, 'device': device}
super(LayerNorm, self).__init__()
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
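
`LayerNorm.__init__` now accepts a `device` keyword and forwards it through `factory_kwargs` alongside `dtype`, mirroring torch's factory-kwargs convention so callers that pass `device=` no longer hit a `TypeError`. A minimal sketch, assuming `mindnlp.core.nn` re-exports this `LayerNorm`:

```python
from mindnlp.core import nn

# device= was previously rejected by the old signature; it is now accepted
# and forwarded via factory_kwargs together with dtype.
ln = nn.LayerNorm(768, eps=1e-6, device=None)
```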