2 changes: 1 addition & 1 deletion mindnlp/core/__init__.py
@@ -46,7 +46,7 @@
from ._bind import get_default_dtype, set_default_dtype

from . import profiler, cuda, optim, amp, compiler, jit, version, __future__, overrides, \
-    return_types, linalg, fx
+    return_types, linalg, fx, backends, testing

from ._lowrank import svd_lowrank
from .random import get_rng_state, initial_seed, manual_seed, seed, set_rng_state
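With backends and testing re-exported from mindnlp.core, torch-style access paths resolve through the package. A minimal sketch, assuming the tensor factory on core mirrors torch.tensor as elsewhere in this package:

    from mindnlp import core

    x = core.tensor([1.0, 2.0, 3.0])        # assumes core.tensor mirrors torch.tensor
    core.testing.assert_close(x, x)         # importable now that testing is re-exported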
22 changes: 22 additions & 0 deletions mindnlp/core/_dtype.py
@@ -11,6 +11,12 @@ def is_floating_point(self):
Type.is_floating_point = is_floating_point
Type.__str__ = Type.__repr__

+@property
+def itemsize(self):
+    return ITEM_SIZE[self]
+
+Type.itemsize = itemsize
+
half = float16
float = float32
double = float64
@@ -22,6 +28,22 @@ def is_floating_point(self):
float8_e4m3fn = None # TODO: not support fp8 for now
float8_e5m2 = None

+ITEM_SIZE = {
+    bool : 1,
+    int8 : 1,
+    int16 : 2,
+    int32 : 4,
+    int64 : 8,
+    uint8 : 1,
+    uint16 : 2,
+    uint32 : 4,
+    uint64 : 8,
+    float16 : 2,
+    bfloat16 : 2,
+    float32 : 4,
+    float64 : 8,
+}
+
np2dtype = {
np.bool_: bool,
np.int8: int8,
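A sketch of what the new itemsize property reports, using the dtype aliases defined in this module; the values come straight from the ITEM_SIZE table:

    from mindnlp.core._dtype import float32, bfloat16, int64

    assert float32.itemsize == 4    # ITEM_SIZE[float32]
    assert bfloat16.itemsize == 2
    assert int64.itemsize == 8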
35 changes: 30 additions & 5 deletions mindnlp/core/_tensor.py
@@ -199,7 +199,7 @@ def __getitem__(self, slices):
     new_slices = ()
     for s in slices:
         if isinstance(s, range):
-            s = slice(s.start, s.stop, s.step)
+            s = list(s)
         new_slices += (s,)
     slices = new_slices
     return origin_getitem(self, slices)
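Materializing a range into a list makes it dispatch as an index array (a gather along dim 0) rather than collapsing it to a basic slice, which matches how PyTorch treats range indices. A sketch, assuming mindnlp.core has been imported so the patched __getitem__ is active:

    import numpy as np
    from mindspore import Tensor

    t = Tensor(np.arange(10))
    print(t[range(2, 8, 2)])   # range -> [2, 4, 6], i.e. fancy indexing: [2 4 6]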
@@ -288,10 +288,8 @@ def unfold(self, dimension, size, step):
 Tensor.unfold = unfold
 StubTensor.unfold = unfold

-def new(self, data=None):
-    if data is None:
-        return Tensor([], dtype=self.dtype)
-    return Tensor(data, dtype=self.dtype)
+def new(self, *shape):
+    return ops.empty(*shape, dtype=self.dtype)

 Tensor.new = new
 StubTensor.new = new
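Tensor.new now follows the legacy torch.Tensor.new(*sizes) form: it allocates an uninitialized tensor of the given shape with the source tensor's dtype, replacing the previous data-copying overload. A sketch:

    import numpy as np
    from mindspore import Tensor

    base = Tensor(np.zeros((2, 2), np.float16))
    buf = base.new(4, 3)             # uninitialized contents, shape (4, 3)
    assert buf.dtype == base.dtype   # dtype inherited from base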
@@ -310,6 +308,33 @@ def cpu(self):
 Tensor.cpu = cpu
 StubTensor.cpu = cpu

+Tensor.take = ops.take
+StubTensor.take = ops.take
+
+Tensor.sort = ops.sort
+StubTensor.sort = ops.sort
+
+def requires_grad_(self, requires_grad=True):
+    self.requires_grad = requires_grad
+    return self
+
+Tensor.requires_grad_ = requires_grad_
+StubTensor.requires_grad_ = requires_grad_
+
+@property
+def data(self):
+    return Tensor(self)
+
+@data.setter
+def data(self, new_value):
+    if isinstance(self, StubTensor) and isinstance(new_value, StubTensor):
+        self.stub = new_value.stub
+    else:
+        self.assign_value(new_value)
+
+Tensor.data = data
+StubTensor.data = data
+
 def _rebuild_from_type_v2(func, new_type, args, state):
     ret = func(*args)
     return ret
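The new shims mirror their torch counterparts: requires_grad_ flips the flag in place and returns self so calls chain, while the data property hands back a fresh Tensor over the same values and its setter swaps the underlying value (or the stub, for StubTensor). A rough sketch of the intended use, assuming the patches are applied by importing mindnlp.core:

    import numpy as np
    from mindspore import Tensor

    w = Tensor(np.ones((2, 2), np.float32))
    w.requires_grad_()                              # sets the flag, returns w for chaining
    snapshot = w.data                               # fresh Tensor wrapping the same values
    w.data = Tensor(np.zeros((2, 2), np.float32))   # setter assigns the new value in place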
Empty file.
Empty file.
4 changes: 2 additions & 2 deletions mindnlp/core/nn/modules/module.py
@@ -1031,7 +1031,7 @@ def remove_from(*dicts_or_sets):
             self.register_parameter(name, value)
         else:
             modules = self.__dict__.get("_modules")
-            if isinstance(value, core.nn.Module):
+            if isinstance(value, Module):
                 if modules is None:
                     raise AttributeError(
                         "cannot assign module before Module.__init__() call"
@@ -1565,7 +1565,7 @@ def get_submodule(self, target: str) -> "Module":

         mod = getattr(mod, item)

-        if not isinstance(mod, core.nn.Module):
+        if not isinstance(mod, Module):
             raise AttributeError("`" + item + "` is not "
                                  "an nn.Module")

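Both checks now test against the local Module class instead of the core.nn.Module attribute path, so any Module subclass is accepted however it was imported, and the check no longer depends on core.nn being fully initialized. The __setattr__ path being guarded, as a hypothetical sketch:

    class Child(Module):      # stands in for any submodule type
        pass

    class Net(Module):
        def __init__(self):
            super().__init__()
            self.child = Child()   # isinstance(value, Module) -> registered in _modules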
5 changes: 4 additions & 1 deletion mindnlp/core/ops/array.py
@@ -227,7 +227,8 @@ def split(tensor, split_size_or_sections, dim=0):

 # squeeze
 has_squeeze = hasattr(mindspore.mint, 'squeeze')
-def squeeze(input, *dim):
+def squeeze(input, *dim, **kwargs):
+    dim = kwargs.get('dim', dim)
     if use_pyboost() and has_squeeze:
         return mindspore.mint.squeeze(input, dim)
     return ops.squeeze(input, dim)
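With the kwargs shim, dim is accepted positionally or as a keyword, matching the torch signature. A sketch, assuming a torch-style ones constructor from this ops package:

    from mindnlp.core import ops

    x = ops.ones(1, 3, 1)                          # assumes torch-style ones(*sizes)
    assert ops.squeeze(x, 0).shape == (3, 1)       # positional, as before
    assert ops.squeeze(x, dim=0).shape == (3, 1)   # keyword dim now routed via kwargs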
@@ -255,6 +256,8 @@ def take(input, index):
     index = index.view(-1)
     if ON_ORANGE_PI:
         return tf_gather(input, index, 0).view(index_shape)
+    if index_shape == ():
+        return gather(input, 0, index)[0]
     return gather(input, 0, index).view(index_shape)

def infer_size_impl(a, b):
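The new guard covers a 0-d index: the flattened input is gathered and the lone element extracted, instead of reshaping the result to the index's empty shape. A sketch of both paths:

    import numpy as np
    from mindspore import Tensor
    from mindnlp.core import ops

    x = Tensor(np.array([[10, 20], [30, 40]]))
    print(ops.take(x, Tensor(3)))        # 0-d index -> the element 40
    print(ops.take(x, Tensor([0, 3])))   # 1-d index -> [10 40], shaped like the index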
7 changes: 5 additions & 2 deletions mindnlp/core/ops/comparison.py
@@ -1,11 +1,13 @@
"""comparison op"""
+from collections import namedtuple
import numpy as np
import mindspore
from mindspore import ops
from ..configs import use_pyboost

from ._inner import call_ms_func

+sort_out = namedtuple('sort_out', ['sorted', 'indices'])
# allclose
has_allclose = hasattr(mindspore.mint, 'allclose')
def allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False):
@@ -167,8 +169,9 @@ def not_equal(input, other):
 has_sort = hasattr(mindspore.mint, 'sort')
 def sort(input, *, dim=-1, descending=False, stable=False):
     if use_pyboost() and has_sort:
-        return mindspore.mint.sort(input, dim=dim, descending=descending, stable=stable)
-    return ops.sort(input, dim, descending)
+        out = mindspore.mint.sort(input, dim=dim, descending=descending, stable=stable)
+    out = ops.sort(input, dim, descending)
+    return sort_out(sorted=out[0], indices=out[1])

# topk
has_topk = hasattr(mindspore.mint, 'topk')
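Returning a namedtuple keeps positional unpacking intact while adding torch-style field access on the result. A sketch:

    import numpy as np
    from mindspore import Tensor
    from mindnlp.core import ops

    res = ops.sort(Tensor(np.array([3, 1, 2])))
    print(res.sorted, res.indices)   # field access, like torch.return_types.sort
    values, indices = res            # tuple unpacking still works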
1 change: 1 addition & 0 deletions mindnlp/core/ops/random.py
@@ -88,6 +88,7 @@ def rand_like(input, *, dtype=None):
 # randint
 has_randint = hasattr(mindspore.mint, 'randint')
 def randint(*args, **kwargs):
+    device = kwargs.pop('device', None)
     if use_pyboost() and has_randint:
         return mindspore.mint.randint(*args, **kwargs)
     return ops.randint(*args, **kwargs)
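Popping device lets torch-style call sites that pass a device keyword go through cleanly; note the value is discarded rather than honored. A sketch:

    from mindnlp.core import ops

    r = ops.randint(0, 10, (2, 3), device='cpu')   # device stripped before dispatch
    print(r.shape)                                 # (2, 3)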
7 changes: 6 additions & 1 deletion mindnlp/core/ops/reduction.py
@@ -1,11 +1,13 @@
"""reduction op"""
+from collections import namedtuple
import mindspore
from mindspore import ops
from mindspore.ops._primitive_cache import _get_cache_prim
from ..configs import use_pyboost, DEVICE_TARGET

from ._inner import call_ms_func

+max_out = namedtuple('max_out', ['values', 'indices'])
# argmax
has_argmax = hasattr(mindspore.mint, 'argmax')
def argmax(input, dim=None, keepdim=False):
@@ -67,7 +69,10 @@ def any(input, dim=None, keepdim=False, *, out=None):
 # max
 has_max = hasattr(mindspore.mint, 'max')
 def max(*args, **kwargs):
-    return mindspore.mint.max(*args, **kwargs)
+    out = mindspore.mint.max(*args, **kwargs)
+    if isinstance(out, tuple):
+        return max_out(values=out[0], indices=out[1])
+    return out

# min
has_min = hasattr(mindspore.mint, 'min')
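max now wraps the two-output reduction in max_out, so both torch idioms work, while single-tensor overloads still return a bare tensor. A sketch:

    import numpy as np
    from mindspore import Tensor
    from mindnlp.core import ops

    x = Tensor(np.array([[1, 5], [4, 2]]))
    out = ops.max(x, dim=1)          # reduction over a dim -> (values, indices)
    print(out.values, out.indices)   # named access
    values, indices = out            # unpacking also works
    print(ops.max(x))                # global max -> a bare tensor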
2 changes: 1 addition & 1 deletion mindnlp/core/overrides.py
@@ -35,7 +35,7 @@ def is_tensor_like(inp):
     >>> is_tensor_like(TensorLike())
     True
     """
-    return type(inp) is core.Tensor or hasattr(inp, "__torch_function__")
+    return isinstance(inp, core.Tensor) or hasattr(inp, "__torch_function__")

def handle_torch_function(
public_api: Callable,
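Switching from an exact type test to isinstance means subclasses of core.Tensor now count as tensor-like too. A hypothetical sketch (the subclass is illustrative only):

    import numpy as np
    from mindnlp import core
    from mindnlp.core.overrides import is_tensor_like

    class WrappedTensor(core.Tensor):   # hypothetical subclass
        pass

    assert is_tensor_like(WrappedTensor(np.zeros(2)))   # the old type(...)-is check rejected this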
1 change: 1 addition & 0 deletions mindnlp/core/testing/__init__.py
@@ -0,0 +1 @@
+from ._comparison import assert_allclose, assert_close as assert_close
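This makes the torch-style comparison helpers importable in tests. A usage sketch, assuming assert_close follows torch.testing.assert_close semantics:

    import numpy as np
    from mindspore import Tensor
    from mindnlp.core.testing import assert_close

    assert_close(Tensor(np.float32([1.0, 2.0])),
                 Tensor(np.float32([1.0, 2.0])))   # identical values pass within default tolerances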