18 changes: 18 additions & 0 deletions mindnlp/accelerate/__init__.py
@@ -0,0 +1,18 @@
import sys
import accelerate
from transformers.utils import _LazyModule

_import_structure = {
    "utils": [
        'DistributedType',
    ]
}

sys.modules[__name__] = _LazyModule(
    'accelerate',
    accelerate.__file__,
    _import_structure,
    module_spec=__spec__,
    extra_objects={"__version__": accelerate.__version__},
)
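
A minimal usage sketch of the lazy re-export above (names taken from _import_structure; it assumes the upstream accelerate package is installed and importable):

# Hypothetical usage sketch: _LazyModule defers importing accelerate.utils
# until DistributedType is first accessed.
from mindnlp import accelerate

print(accelerate.DistributedType)   # resolved lazily on first attribute access
print(accelerate.__version__)       # exposed eagerly via extra_objects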
25 changes: 24 additions & 1 deletion mindnlp/core/__init__.py
@@ -46,5 +46,28 @@
from ._bind import get_default_dtype, set_default_dtype

from . import profiler, cuda, optim, amp, compiler, jit, version, __future__, overrides, \
return_types
return_types, linalg

from ._lowrank import svd_lowrank

def _has_compatible_shallow_copy_type(tensor, other):
    """
    Mimics the behavior of mindtorch._has_compatible_shallow_copy_type.

    Args:
        tensor (mindtorch.Tensor): The source tensor.
        other (mindtorch.Tensor): The target tensor to check compatibility.

    Returns:
        bool: True if `tensor` and `other` have compatible types for shallow copy.
    """
    # Both arguments must be tensors
    if not is_tensor(tensor) or not is_tensor(other):
        return False

    # Shapes must match for a shallow copy to be compatible
    if tensor.shape != other.shape:
        return False

    # Compatibility confirmed
    return True
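
A short, hedged illustration of the helper above (it assumes tensor and the helper are reachable from mindnlp.core, as the surrounding imports suggest; the values are placeholders):

# Hypothetical usage sketch for _has_compatible_shallow_copy_type.
from mindnlp import core

a = core.tensor([[1.0, 2.0], [3.0, 4.0]])
b = core.tensor([[0.0, 0.0], [0.0, 0.0]])
c = core.tensor([1.0, 2.0])

assert core._has_compatible_shallow_copy_type(a, b)      # same shape -> compatible
assert not core._has_compatible_shallow_copy_type(a, c)  # shape mismatch -> incompatible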
2 changes: 2 additions & 0 deletions mindnlp/core/_dtype.py
@@ -9,6 +9,7 @@ def is_floating_point(self):
    return isinstance(self, (typing.Float, typing.BFloat16))

Type.is_floating_point = is_floating_point
Type.__str__ = Type.__repr__

half = float16
float = float32
@@ -19,6 +20,7 @@ def is_floating_point(self):
bool = bool_

float8_e4m3fn = None # TODO: not support fp8 for now
float8_e5m2 = None

np2dtype = {
np.bool_: bool,
3 changes: 0 additions & 3 deletions mindnlp/core/_dynamo/utils.py

This file was deleted.

4 changes: 2 additions & 2 deletions mindnlp/core/_lowrank.py
@@ -5,8 +5,8 @@
from typing import Optional

from mindnlp import core
from mindnlp.core import _linalg_utils as _utils, Tensor
from core.overrides import handle_torch_function, has_torch_function
from . import _linalg_utils as _utils, Tensor
from .overrides import handle_torch_function, has_torch_function


def get_approximate_basis(
175 changes: 160 additions & 15 deletions mindnlp/core/_tensor.py
@@ -1,17 +1,29 @@
import math
import numpy as np
import mindspore
from mindspore import Tensor
from mindspore.common.tensor import _TensorMeta
from mindspore._c_expression.typing import Type
try:
    from mindspore.common._stub_tensor import StubTensor
except:
    class StubTensor: pass

from . import ops, _dtype
from ._dtype import dtype2np
from ._bind import get_default_device

from ._dtype import *
from ._bind import get_default_device, device_
from .configs import use_pyboost, ON_A1
from .storage import UntypedStorage

DTYPE_ELEMENT_SIZE_MAP = {
    mindspore.float64: 8,
    mindspore.int64: 8,
    mindspore.int32: 4,
    mindspore.float32: 4,
    mindspore.int16: 2,
    mindspore.bfloat16: 2,
    mindspore.float16: 2,
}

class TypedTensorMeta(_TensorMeta):
    def __isinstancecheck__(self, instance):
@@ -20,31 +32,31 @@ def __isinstancecheck__(self, instance):
        return instance.dtype == self.dtype

class LongTensor(Tensor, metaclass=TypedTensorMeta):
    dtype = long
    dtype = _dtype.long
    def __init__(self, data, device=None):
        super().__init__(data, dtype=long)
        super().__init__(data, dtype=_dtype.long)

class FloatTensor(Tensor, metaclass=TypedTensorMeta):
    dtype = float32
    dtype = _dtype.float32
    def __init__(self, data, device=None):
        super().__init__(data, dtype=float32)
        super().__init__(data, dtype=_dtype.float32)


class HalfTensor(Tensor, metaclass=TypedTensorMeta):
    dtype = float16
    dtype = _dtype.float16
    def __init__(self, data, device=None):
        super().__init__(data, dtype=float16)
        super().__init__(data, dtype=_dtype.float16)

class BFloat16Tensor(Tensor, metaclass=TypedTensorMeta):
    dtype = float16
    dtype = _dtype.float16
    def __init__(self, data, device=None):
        super().__init__(data, dtype=bfloat16)
        super().__init__(data, dtype=_dtype.bfloat16)


class BoolTensor(Tensor, metaclass=TypedTensorMeta):
    dtype = bool
    dtype = _dtype.bool
    def __init__(self, data, device=None):
        super().__init__(data, dtype=bool)
        super().__init__(data, dtype=_dtype.bool)

def tensor(data, *, dtype=None, device=None, requires_grad=False):
    if isinstance(data, Tensor):
@@ -105,8 +117,8 @@ def data_ptr(self):
Tensor.data_ptr = data_ptr
StubTensor.data_ptr = data_ptr

Tensor.device = None
StubTensor.device = None
Tensor.device = device_('not support yet.')
StubTensor.device = device_('not support yet.')

def _expand(self, *size):
    if len(size) == 1:
@@ -115,3 +127,136 @@ def _expand(self, *size):

Tensor.expand = _expand
StubTensor.expand = _expand

def clone(self, *args, **kwargs):
    return self.copy()

Tensor.clone = clone
StubTensor.clone = clone

def _repeat(self, *sizes):
    if len(sizes) == 1 and isinstance(sizes[0], (list, tuple)):
        sizes = sizes[0]
    return ops.tile(self, tuple(sizes))

Tensor.repeat = _repeat
StubTensor.repeat = _repeat

def __or__(self, other):
    if isinstance(other, (int, bool, float, Tensor)):
        return ops.bitwise_or(self, other)
    raise TypeError("Unsupported operand type(s) for |: 'Tensor' and '{}'".format(type(other)))

Tensor.__or__ = __or__
StubTensor.__or__ = __or__

def __and__(self, other):
    if isinstance(other, (int, bool, float, Tensor)):
        return ops.bitwise_and(self, other)
    raise TypeError("Unsupported operand type(s) for &: 'Tensor' and '{}'".format(type(other)))

Tensor.__and__ = __and__
StubTensor.__and__ = __and__

def detach(self):
    return ops.stop_gradient(self)

Tensor.detach = detach
StubTensor.detach = detach

origin_getitem = Tensor.__getitem__
def __getitem__(self, slices):
    # if 0 in self.shape:
    #     return self
    if isinstance(slices, tuple):
        new_slices = ()
        for s in slices:
            if isinstance(s, range):
                s = slice(s.start, s.stop, s.step)
            new_slices += (s,)
        slices = new_slices
    return origin_getitem(self, slices)

Tensor.__getitem__ = __getitem__
StubTensor.__getitem__ = __getitem__

def numel(self):
    return math.prod(self.shape)

Tensor.numel = numel
StubTensor.numel = numel
Tensor.nelement = numel
StubTensor.nelement = numel

@property
def nbytes(self):
    return self.numel() * self.element_size()

Tensor.nbytes = nbytes
StubTensor.nbytes = nbytes

Tensor.normal_ = ops.inplace_normal
StubTensor.normal_ = ops.inplace_normal


Tensor.softmax = ops.softmax
StubTensor.softmax = ops.softmax

Tensor.squeeze = ops.squeeze
StubTensor.squeeze = ops.squeeze

Tensor.unsqueeze = ops.unsqueeze
StubTensor.unsqueeze = ops.unsqueeze

def log_softmax(self, dim=-1, dtype=None):
    if use_pyboost():
        return mindspore.mint.nn.functional.log_softmax(self, dim=dim, dtype=dtype)
    out = mindspore.ops.log_softmax(self, dim)
    if dtype is not None:
        out = out.to(dtype)
    return out

Tensor.log_softmax = log_softmax
StubTensor.log_softmax = log_softmax

def untyped_storage(self):
    return UntypedStorage(self)

Tensor.untyped_storage = untyped_storage
StubTensor.untyped_storage = untyped_storage

def element_size(self):
    return DTYPE_ELEMENT_SIZE_MAP[self.dtype]

Tensor.element_size = element_size
StubTensor.element_size = element_size

@property
def layout(self):
    return None

Tensor.layout = layout
StubTensor.layout = layout

def __add__(self, other):
    # if 0 in self.shape:
    #     return self
    return ops.add(self, other)

Tensor.__add__ = __add__
StubTensor.__add__ = __add__

Tensor.repeat_interleave = ops.repeat_interleave
StubTensor.repeat_interleave = ops.repeat_interleave

def dim(self):
    return self.ndim

Tensor.dim = dim
StubTensor.dim = dim

def unfold(self, dimension, size, step):
    return ops.unfold(self, dimension, size, step)

Tensor.unfold = unfold
StubTensor.unfold = unfold
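
Taken together, these patches bolt torch-style conveniences onto mindspore.Tensor. A rough sketch of how a few of the added methods would be exercised (method names come from the patches above; exact outputs are not verified here):

# Hypothetical usage sketch exercising some of the patched Tensor methods.
from mindnlp import core

x = core.tensor([[1.0, 2.0], [3.0, 4.0]])

y = x.clone()                  # Tensor.clone -> self.copy()
print(x.numel(), x.dim())      # 4 elements, 2 dimensions
print(x.element_size())        # 4 bytes for float32 (DTYPE_ELEMENT_SIZE_MAP)
print(x.nbytes)                # numel() * element_size() = 16
tiled = x.repeat(2, 1)         # delegates to ops.tile
z = x.unsqueeze(0).squeeze(0)  # bound directly to ops.unsqueeze / ops.squeeze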
2 changes: 1 addition & 1 deletion mindnlp/core/autograd/__init__.py
@@ -1,4 +1,4 @@
"""autograd"""
from .node import Node
from .function import Function
from .function import Function, value_and_grad
from .grad_mode import no_grad, enable_grad, inference_mode
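
A brief, hedged sketch of the newly exported names (it assumes no_grad behaves as a torch-style context manager; the exact value_and_grad signature lives in mindnlp/core/autograd/function.py and is not shown in this diff):

# Hypothetical usage sketch of the autograd re-exports.
from mindnlp.core.autograd import no_grad, value_and_grad  # value_and_grad newly exported

def forward_only(model, batch):
    # Gradient tracking disabled inside the block.
    with no_grad():
        return model(batch)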