File renamed without changes.
1 change: 0 additions & 1 deletion mindnlp/core/_functorch/apis.py
@@ -1,5 +1,4 @@
from typing import Callable

import mindspore

def vmap(
8 changes: 8 additions & 0 deletions mindnlp/core/_prims/ascend.py
@@ -396,3 +396,11 @@ def cross(input, other, dim=None, *, out=None):
return pyboost_inner_prim.cross_impl(input, other, dim)

__all__.append('cross')

def logit(input, eps):
if eps is None:
eps = -1.0
logit_ = _get_cache_prim(ops.Logit)(eps).set_device('Ascend')
return logit_(input)

__all__.append('logit')
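Reviewer note (not part of the diff): logit is the inverse of the sigmoid. A minimal NumPy sketch of the intended semantics, assuming the eps=-1.0 sentinel tells the Ascend kernel to skip clamping, mirroring torch.logit where eps=None means no clamp:

import numpy as np

def logit_reference(x, eps=None):
    # logit(x) = log(x / (1 - x)); optionally clamp x into [eps, 1 - eps]
    if eps is not None:
        x = np.clip(x, eps, 1 - eps)
    return np.log(x / (1 - x))

logit_reference(np.array([0.25, 0.5, 0.75]))  # -> [-1.0986, 0.0, 1.0986]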
14 changes: 13 additions & 1 deletion mindnlp/core/_prims/meta.py
@@ -65,7 +65,9 @@ def getitem(input, slice):
__all__.append('getitem')

def sub_ext(input, other, alpha):
return input
if isinstance(input, core.Tensor):
return input
return other

__all__.append('sub_ext')

@@ -341,3 +343,13 @@ def reverse_v2(input, dims):
return input

__all__.append('reverse_v2')

def rsqrt(input):
return input

__all__.append('rsqrt')

def bitwise_xor_tensor(input, other):
return input

__all__.append('bitwise_xor_tensor')
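Reviewer note (not part of the diff): the meta backend only propagates shapes and dtypes, never values (an assumption based on the torch meta-device convention), so elementwise ops like rsqrt and bitwise_xor can simply echo their input's metadata. The sub_ext change additionally covers the case where input is a Python scalar, falling back to other so a Tensor is always returned. A sketch of the intent; the 'meta' device string is hypothetical:

x = core.empty(4, 8, device='meta')  # carries shape/dtype only, no data
y = rsqrt(x)                         # no arithmetic runs on a meta tensor
assert y.shape == (4, 8)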
66 changes: 62 additions & 4 deletions mindnlp/core/_prims/numpy.py
@@ -753,10 +753,13 @@ def std(input, dim, correction, keepdim):
__all__.append('std')

def meshgrid(tensors, indexing):
out = np.meshgrid(*[t.numpy() for t in tensors], indexing=indexing)
if not isinstance(out, np.ndarray):
out = np.array(out)
return core.Tensor.from_numpy(out)
outs = np.meshgrid(*[t.numpy() for t in tensors], indexing=indexing)
new_outs = ()
for out in outs:
if not isinstance(out, np.ndarray):
out = np.array(out)
new_outs += (core.Tensor.from_numpy(out),)
return new_outs

__all__.append('meshgrid')
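Reviewer note (not part of the diff): the old code wrapped np.meshgrid's list of outputs in a single np.array, collapsing N coordinate grids into one stacked tensor; callers expecting torch.meshgrid semantics need a tuple of N tensors. Expected behavior after the fix, as a sketch:

xs = core.tensor([0., 1., 2.])
ys = core.tensor([0., 1.])
gx, gy = meshgrid((xs, ys), indexing='ij')  # tuple of two tensors
# gx.shape == gy.shape == (3, 2)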

@@ -809,3 +812,58 @@ def reverse_v2(input, dims):
return core.Tensor.from_numpy(out)

__all__.append('reverse_v2')

def rsqrt(input):
out = np.reciprocal(np.sqrt(input.numpy()))
if not isinstance(out, np.ndarray):
out = np.array(out)
return core.Tensor.from_numpy(out)

__all__.append('rsqrt')

def bitwise_xor_tensor(input, other):
out = np.bitwise_xor(input.numpy(), other.numpy())
return core.Tensor.from_numpy(out)

__all__.append('bitwise_xor_tensor')

def minimum(input, other):
out = np.minimum(input.numpy(), other.numpy())
return core.Tensor.from_numpy(out)

__all__.append('minimum')

def prod_ext(input, dim, keepdim, dtype):
out = np.prod(input.numpy(), axis=dim, keepdims=keepdim)
return core.Tensor.from_numpy(out)

__all__.append('prod_ext')

def select(condition, input, other):
if not isinstance(input, numbers.Number):
input = input.numpy()
if not isinstance(other, numbers.Number):
other = other.numpy()

out = np.where(condition.numpy(), input, other)
return core.Tensor.from_numpy(out)

__all__.append('select')

def dense(input, weight, bias):
output = np.dot(input.numpy(), weight.numpy().T)
if bias is not None:
output += bias
return core.Tensor.from_numpy(output)

__all__.append('dense')

def dropout_ext(input, p):
if p != 0:
mask = (np.random.rand(*input.shape) < (1 - p))
out = input.numpy() * mask / (1 - p)
return core.Tensor.from_numpy(out), core.Tensor.from_numpy(mask)
else:
return input, None

__all__.append('dropout_ext')
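Reviewer note (not part of the diff): two remarks on the new kernels. dense computes y = x @ W.T + b, matching the weight layout of torch.nn.functional.linear. dropout_ext is inverted dropout: survivors are scaled by 1/(1 - p) at train time so the output's expectation equals the input and inference needs no rescaling. A NumPy sketch of the scaling argument, illustration only:

import numpy as np

p = 0.2
x = np.ones(100_000)
mask = np.random.rand(*x.shape) < (1 - p)  # keep each unit with prob 1 - p
out = x * mask / (1 - p)                   # rescale the survivors
print(out.mean())                          # ~1.0: E[out] = x*(1-p)/(1-p) = x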
17 changes: 11 additions & 6 deletions mindnlp/core/_tensor.py
@@ -196,7 +196,7 @@ def __len__(self):
return self.shape[0]

def __repr__(self) -> str:
self.data_sync(True)
# self.data_sync(True)
return Tensor_.__repr__(self)[:-1] + f', device={self.device})'

def __format__(self, format_spec):
@@ -982,8 +982,8 @@ def diagnoal(self, offset=0, dim1=0, dim2=1):


# Tensor.div
def div(self, other):
return ops.div(self, other)
def div(self, other, rounding_mode=None):
return ops.div(self, other, rounding_mode=rounding_mode)

# Tensor.div_
def div_(self, other):
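Reviewer note (not part of the diff): before this fix Tensor.div silently dropped rounding_mode; it is now forwarded, so torch-style rounding division should behave as expected. A sketch:

a = core.tensor([ 7., -7.])
b = core.tensor([ 2.,  2.])
a.div(b)                          # [ 3.5, -3.5]
a.div(b, rounding_mode='trunc')   # [ 3.0, -3.0]  round toward zero
a.div(b, rounding_mode='floor')   # [ 3.0, -4.0]  round toward -inf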
@@ -1257,13 +1257,18 @@ def index_add_(self, dim, index, source, *, alpha=1):

# Tensor.index_add
def index_add(self, dim, index, source, *, alpha=1):
return ops.index_add(self, dim, source, alpha=alpha)
return ops.index_add(self, dim, index, source, alpha=alpha)

# Tensor.index_copy_

def index_copy_(self, dim, index, tensor2):
return self.copy_(self.index_copy(dim, index, tensor2))

# Tensor.index_copy

def index_copy(self, dim, index, tensor2):
original_values_at_index = self.index_select(dim, index)
result = self.index_add(dim, index, -original_values_at_index)
result.index_add_(dim, index, tensor2)
return result

# Tensor.index_fill_

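Reviewer note (not part of the diff): how the new index_copy works — the first index_add cancels the current values at index (adding their negation zeroes those slots), and the in-place index_add_ then writes tensor2 there. This holds only when index has no duplicates; repeated indices would accumulate rather than copy. A worked trace:

# self = [10, 20, 30], dim = 0, index = [0, 2], tensor2 = [1, 3]
# step 1: result = self + scatter(-self[index])  -> [0, 20, 0]
# step 2: result.index_add_(0, index, tensor2)   -> [1, 20, 3]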
7 changes: 5 additions & 2 deletions mindnlp/core/nn/attention/__init__.py
@@ -1,5 +1,8 @@
import contextlib

class SDPBackend:
pass
MATH = 0

@contextlib.contextmanager
def sdpa_kernel(*args, **kwargs):
pass
yield {}
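Reviewer note (not part of the diff): two fixes here. A @contextlib.contextmanager generator that never yields raises RuntimeError("generator didn't yield") on entry, so the stub now yields an empty dict; and SDPBackend.MATH = 0 gives callers that reference torch.nn.attention.SDPBackend.MATH something to import. Usage sketch; attention_forward is a hypothetical caller:

from mindnlp.core.nn.attention import SDPBackend, sdpa_kernel

with sdpa_kernel(SDPBackend.MATH):    # accepted for torch compat; no-op shim
    out = attention_forward(q, k, v)  # hypothetical caller code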
47 changes: 47 additions & 0 deletions mindnlp/core/nn/functional.py
@@ -1590,9 +1590,56 @@ def pixel_shuffle(input, upscale_factor):
def pixel_unshuffle(input, downscale_factor):
return ops.pixel_unshuffle(input, downscale_factor)

def getWH(input):
"""Get [W, H] tensor from input"""
H, W = input.size()[-2:]
return core.tensor([[W, H]], dtype=core.float32, device=input.device)

def center_of(input):
"""return [(W-1)/2, (H-1)/2] tensor of input img"""
if input.dim() == 4:
H, W = input.size()[-2:]
shape = [[W, H]]
else:
D, H, W = input.size()[-3:]
shape = [[W, H, D]]
return core.tensor(shape, dtype=core.float32, device=input.device).sub_(1).div_(2)

def u(s, a: float = -0.75):
s2, s3 = s**2, s**3
l1 = (a+2)*s3 - (a+3)*s2 + 1
l2 = a*s3 - (5*a)*s2 + (8*a)*s - 4*a
return l1.where(s <= 1, l2)
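Reviewer note (not part of the diff): u matches the standard Keys cubic-convolution kernel with a = -0.75, the value PyTorch and OpenCV use for bicubic interpolation; the first branch covers |s| <= 1 and the second 1 < |s| < 2:

W(s) = (a+2)|s|^3 - (a+3)|s|^2 + 1       for |s| <= 1
W(s) = a|s|^3 - 5a|s|^2 + 8a|s| - 4a     for 1 < |s| < 2
W(s) = 0                                 otherwise

The caller below passes s = |abs_loc - locs|, which lies in [0, 2), so the where(s <= 1, ...) selects the correct branch.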

def bicubic_grid_sample(input, grid, padding_mode: str = 'zeros', align_corners: bool = False):
"""bicubic_grid_sample"""
kernel_size = 4
if not align_corners:
grid = grid * getWH(input) / getWH(input).sub_(1)
center = center_of(input)
abs_loc = ((grid + 1) * center).unsqueeze(-1)

locs = abs_loc.floor() + core.tensor([-1, 0, 1, 2], device=grid.device)

loc_w, loc_h = locs.detach().flatten(0, 2).unbind(dim=-2)
loc_w = loc_w.reshape(-1, 1, kernel_size).expand(-1, kernel_size, -1)
loc_h = loc_h.reshape(-1, kernel_size, 1).expand(-1, -1, kernel_size)
loc_grid = core.stack([loc_w, loc_h], dim=-1)
loc_grid = loc_grid.view(grid.size(0), -1, 1, 2)/center - 1

selected = grid_sample(input, loc_grid.detach(), mode='nearest',
padding_mode=padding_mode, align_corners=True)
patch = selected.view(input.size()[:2]+grid.size()[1:3]+(kernel_size,)*2)

mat_r, mat_l = u(core.abs(abs_loc - locs.detach())).unbind(dim=-2)
output = core.einsum('bhwl,bchwlr,bhwr->bchw', mat_l, patch, mat_r)
return output

def grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=False):
align_corners = False if align_corners is None else align_corners
if input.ndim == 4:
if mode == 'bicubic':
return bicubic_grid_sample(input, grid, padding_mode, align_corners)
return execute('grid_sampler_2d', input, grid, mode, padding_mode, align_corners)
return execute('grid_sampler_3d', input, grid, mode, padding_mode, align_corners)

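Reviewer note (not part of the diff): grid_sample now routes 4-D bicubic requests through the pure-Python path above — gathering each output pixel's 4x4 neighborhood with a nearest-mode grid_sample, then applying the cubic weights via einsum — rather than the executor, which presumably lacks a bicubic kernel. Shapes follow torch.nn.functional.grid_sample; a sketch assuming core exposes randn/rand like torch:

inp  = core.randn(1, 3, 8, 8)            # (N, C, H, W)
grid = core.rand(1, 16, 16, 2) * 2 - 1   # (N, H_out, W_out, 2) in [-1, 1]
out  = grid_sample(inp, grid, mode='bicubic', align_corners=False)
# out.shape == (1, 3, 16, 16)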
14 changes: 10 additions & 4 deletions mindnlp/core/ops/creation.py
@@ -39,9 +39,9 @@ def zeros(*size, out=None, dtype=None, layout=None, device=None, requires_grad=False):

if isinstance(device, str):
device = core.device(device)
if isinstance(size[0], (tuple, list)):
if len(size) > 0 and isinstance(size[0], (tuple, list)):
size = size[0]

new_size = ()
for s in size:
if not isinstance(s, int):
@@ -139,6 +139,10 @@ def linspace(start, end, steps, *, out=None, dtype=None, layout=None, device=None, requires_grad=False):
if isinstance(device, str):
device = core.device(device)

start = start.item() if isinstance(start, (core.Tensor, np.integer)) else start
end = end.item() if isinstance(end, (core.Tensor, np.integer)) else end
steps = steps.item() if isinstance(steps, (core.Tensor, np.integer)) else steps

output = execute('lin_space_ext', start, end, steps, dtype,
device=device, requires_grad=requires_grad, user_created=True)
if out is None:
@@ -154,6 +158,8 @@ def eye(n, m=None, *, out=None, dtype=None, layout=None, device=None, requires_grad=False):
device = get_device_in_context()
if dtype is None:
dtype = get_default_dtype()
if m is None:
m = n
output = execute('eye', n, m, dtype,
device=device, requires_grad=requires_grad, user_created=True)
if out is None:
@@ -194,8 +200,8 @@ def empty_like(input, *, dtype=None, layout=None, device=None, requires_grad=False):

# full
def full(size, fill_value, *, out=None, dtype=None, layout=None, device=None, requires_grad=False):
if dtype is None:
dtype = get_default_dtype()
# if dtype is None:
# dtype = get_default_dtype()
if device is None:
device = get_device_in_context()
if device.type == 'cpu':
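Reviewer note (not part of the diff): these guards cover torch-compatible edge cases — zeros() with no size arguments no longer indexes an empty tuple (presumably producing a 0-d tensor), eye(n) defaults m to n, and linspace unwraps 0-d tensor or NumPy-integer endpoints before dispatch. A sketch:

zeros()             # empty size; size[0] previously raised IndexError
zeros(2, 3)         # varargs form
zeros((2, 3))       # tuple form
eye(3)              # now equivalent to eye(3, 3)
linspace(core.tensor(0.), core.tensor(1.), 5)  # 0-d tensor endpoints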
2 changes: 1 addition & 1 deletion mindnlp/core/ops/reduction.py
@@ -152,7 +152,7 @@ def nansum(input, dim=None, keepdim=False, *, dtype=None):

# prod
def prod(input, dim=None, keepdim=False, *, dtype=None):
return execute('prod_ext', input, dim, keepdim,dtype)
return execute('prod_ext', input, dim, keepdim, dtype)

# quantile

4 changes: 2 additions & 2 deletions mindnlp/core/special/__init__.py
@@ -1,4 +1,4 @@
from mindspore import ops
from ..executor import execute

def logit(input, eps=None, *, out=None):
return ops.logit(input, eps)
return execute('logit', input, eps)
Empty file removed mindnlp/factory/__init__.py
18 changes: 14 additions & 4 deletions mindnlp/transformers/__init__.py
@@ -1,4 +1,5 @@
import sys
from packaging import version
from mindnlp.core.configs import ON_ORANGE_PI
from mindnlp.utils.import_utils import *
from mindnlp.utils.import_utils import _LazyModule
@@ -4102,12 +4103,13 @@


from . import ms_utils
from .masking_utils import create_causal_mask, create_sliding_window_causal_mask
from .masking_utils import create_causal_mask, create_sliding_window_causal_mask, create_masks_for_generate
from .modeling_utils import construct_pipeline_parallel_model, _load_pretrained_model_wrapper, \
_get_resolved_checkpoint_files_wrapper
from .tokenization_utils import apply_chat_template_wrapper
from .trainer import training_step
from .generation import *
from .models.gemma3 import token_type_ids_mask_function

# redirect mindnlp.transformers to transformers
import transformers
@@ -4134,9 +4136,14 @@ def empty_fn(*args, **kwargs):
from ..utils.decorators import dtype_wrapper, patch_dtype_wrapper, patch_wrappers

patch_dtype_wrapper(transformers.AutoModel, 'from_pretrained')
patch_dtype_wrapper(transformers.modeling_utils.PreTrainedModel, 'from_pretrained',
[transformers.modeling_utils.restore_default_torch_dtype]
)
if version.parse(transformers.__version__) >= version.parse('4.56.0'):
patch_dtype_wrapper(transformers.modeling_utils.PreTrainedModel, 'from_pretrained',
[transformers.modeling_utils.restore_default_dtype]
)
else:
patch_dtype_wrapper(transformers.modeling_utils.PreTrainedModel, 'from_pretrained',
[transformers.modeling_utils.restore_default_torch_dtype]
)
patch_wrappers(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model',
[_load_pretrained_model_wrapper])
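Reviewer note (not part of the diff): the gate tracks transformers 4.56.0, where restore_default_torch_dtype was renamed to restore_default_dtype; patching the old symbol on newer versions would raise AttributeError at import time.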

@@ -4152,6 +4159,8 @@ def empty_fn(*args, **kwargs):
transformers.modeling_utils.caching_allocator_warmup = empty_fn
transformers.masking_utils.create_causal_mask = create_causal_mask
transformers.masking_utils.create_sliding_window_causal_mask = create_sliding_window_causal_mask
transformers.masking_utils.create_masks_for_generate = create_masks_for_generate
transformers.generation.utils.create_masks_for_generate = create_masks_for_generate

transformers.trainer.Trainer.training_step = training_step
# for ORANGE_PI
@@ -4160,3 +4169,4 @@ def empty_fn(*args, **kwargs):

# add mindnlp.transformers modules/attrs to lazymodule
# setattr(sys.modules[__name__], 'test_ms_model', test_ms_model)
transformers.models.gemma3.modeling_gemma3.token_type_ids_mask_function = token_type_ids_mask_function