22 changes: 22 additions & 0 deletions mindnlp/core/_dynamo/__init__.py
@@ -0,0 +1,22 @@
from .decorators import (
    # allow_in_graph,
    # assume_constant_result,
    # disable,
    # disallow_in_graph,
    # dont_skip_tracing,
    # forbid_in_graph,
    # graph_break,
    # mark_dynamic,
    # mark_static,
    mark_static_address,
    # maybe_mark_dynamic,
    # nonstrict_trace,
    # patch_dynamo_config,
    # run,
    # set_fullgraph,
    # set_stance,
    # skip_frame,
    # substitute_in_graph,
)

from . import eval_frame
17 changes: 17 additions & 0 deletions mindnlp/core/_dynamo/decorators.py
@@ -0,0 +1,17 @@
from typing import Any, Callable, Optional, overload, TYPE_CHECKING, TypeVar, Union

def mark_static_address(t: Any, guard: bool = True) -> None:
    """
    Marks an input tensor whose data_ptr will not change across multiple calls
    to a dynamo-compiled function. This indicates to cudagraphs that an extra allocation
    is not needed for this input. The data_ptr will be guarded if guard=True. Note:
    Tensors marked in this way will be kept alive until `torch._dynamo.reset()` is called.
    """
    # if not isinstance(t, torch.Tensor):
    #     raise TypeError(f"mark_static_address expects a tensor but received {type(t)}")

    # if guard:
    #     t._dynamo_static_input_type = "guarded"  # type: ignore[attr-defined]
    # else:
    #     t._dynamo_static_input_type = "unguarded"  # type: ignore[attr-defined]
    pass
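For reference, a minimal usage sketch of this stub (hedged: unlike upstream `torch._dynamo`, the mindnlp version is currently a no-op; the import path simply follows the new `mindnlp/core/_dynamo/__init__.py` above):

```python
# Illustrative only: same call shape as torch._dynamo.mark_static_address,
# but the mindnlp stub does nothing yet.
import mindspore
from mindnlp.core._dynamo import mark_static_address

weight = mindspore.Tensor([[1.0, 2.0], [3.0, 4.0]])
mark_static_address(weight, guard=True)  # upstream would pin/guard weight's data_ptr
```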
2 changes: 2 additions & 0 deletions mindnlp/core/_dynamo/eval_frame.py
@@ -0,0 +1,2 @@
class OptimizedModule:
    pass
9 changes: 9 additions & 0 deletions mindnlp/core/_tensor.py
@@ -148,6 +148,10 @@ def is_meta(self):
StubTensor.is_meta = is_meta

def data_ptr(self):
    ptr = self._data_ptr()
    if ptr != 0:
        return ptr
    # A zero pointer means the storage has not been materialized yet; run a
    # trivial op to force allocation, then read the pointer again.
    self + 1
    return self._data_ptr()

Tensor.data_ptr = data_ptr
@@ -439,9 +443,14 @@ def __rmul__(self, other):

def clamp_min(self, value):
    return ops.clamp(self, value)

Tensor.clamp_min = clamp_min
StubTensor.clamp_min = clamp_min

Tensor.index_copy_ = ops.inplace_index_copy
StubTensor.index_copy_ = ops.inplace_index_copy


def _rebuild_from_type_v2(func, new_type, args, state):
    ret = func(*args)
    return ret
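A rough sketch of what the patched `data_ptr` provides (assumptions: importing `mindnlp.core` applies these Tensor patches, and `_data_ptr` is the original pointer accessor saved earlier in this module):

```python
# Sketch: a lazily-evaluated tensor can report a 0 pointer; the patched
# data_ptr() forces materialization with a trivial op before re-reading it.
import mindspore
import mindnlp.core  # assumed to apply the Tensor patches on import

x = mindspore.ops.ones((2, 2), mindspore.float32)
ptr = x.data_ptr()   # expected to be non-zero once storage is materialized
print(hex(ptr))
```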
13 changes: 11 additions & 2 deletions mindnlp/core/nn/functional.py
@@ -141,7 +141,7 @@ def avg_pool1d(input, kernel_size, stride, padding=0, ceil_mode=False, count_inc

    return output_array

def avg_pool2d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=0):
def avg_pool2d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None):
    """
    Perform 2D average pooling on the input array.

@@ -156,11 +156,17 @@ def avg_pool2d(input, kernel_size, stride=None, padding=0, ceil_mode=False, coun
    Returns:
        - numpy array: The result of the average pooling operation.
    """
    if use_pyboost():
        return mindspore.ops.function.nn_func.avg_pool2d_ext(input, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)
        return mint.nn.functional.avg_pool2d(input, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)

    return ops.avg_pool2d(input, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)

def adaptive_avg_pool2d(input, output_size):
    if use_pyboost():
        return mint.nn.functional.adaptive_avg_pool2d(input, output_size)
    return ops.adaptive_avg_pool2d(input, output_size)

def dropout(input, p=0.5, training=True):
    if not training or p == 0:
        return input
@@ -300,6 +306,9 @@ def pad(input, pad, mode='constant', value=0.0):
        new_pad += (pad_v,)
    if sum(new_pad) == 0:
        return input
    if input.dtype == mindspore.bool_:
        input = input.to(mindspore.int32)
        return ops.pad(input, new_pad, mode, value).to(mindspore.bool_)
    return ops.pad(input, new_pad, mode, value)

def nll_loss(input, target, weight=None, ignore_index=-100, reduction='mean', label_smoothing=0.0):
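A small sketch of the new bool-padding path in `pad` (assumption: `mindnlp.core.nn.functional` is the import path; `ops.pad` does not accept bool inputs, so the mask is padded as int32 and cast back):

```python
import numpy as np
import mindspore
from mindnlp.core.nn import functional as F  # path per this PR's file layout

mask = mindspore.Tensor(np.array([[True, False], [False, True]]))
padded = F.pad(mask, (1, 1))               # pad the last dim by 1 on each side
print(padded.dtype, padded.shape)          # expected: Bool, (2, 4)
```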
1 change: 1 addition & 0 deletions mindnlp/core/nn/modules/batchnorm.py
@@ -1,6 +1,7 @@
"""batch norm"""
from typing import Optional
from mindnlp.core import Tensor
from mindnlp import core
from ..parameter import Parameter

from .module import Module
4 changes: 2 additions & 2 deletions mindnlp/core/nn/modules/pooling.py
@@ -286,7 +286,7 @@ class AdaptiveAvgPool2d(_AdaptiveAvgPoolNd):
    output_size: _size_2_opt_t

    def forward(self, input: Tensor) -> Tensor:
        return ops.adaptive_avg_pool2d(input, self.output_size)
        return F.adaptive_avg_pool2d(input, self.output_size)

class AdaptiveAvgPool1d(_AdaptiveAvgPoolNd):
    r"""Applies a 1D adaptive average pooling over an input signal composed of several input planes.
@@ -479,7 +479,7 @@ def __init__(self, kernel_size: _size_2_t, stride: Optional[_size_2_t] = None, p
        self.padding = padding
        self.ceil_mode = ceil_mode
        self.count_include_pad = count_include_pad
        self.divisor_override = divisor_override if divisor_override is not None else 0
        self.divisor_override = divisor_override

    def forward(self, input: Tensor) -> Tensor:
        return F.avg_pool2d(input, self.kernel_size, self.stride,
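Hedged sketch of the intent of the `divisor_override` change: `AvgPool2d` now forwards `None` (the PyTorch-style default, meaning 'divide by the window size') instead of coercing it to 0 (assumption: `mindnlp.core.nn` re-exports the patched pooling modules):

```python
import numpy as np
import mindspore
from mindnlp.core import nn  # assumed export of the patched pooling modules

pool = nn.AvgPool2d(kernel_size=2, stride=2)   # divisor_override stays None
x = mindspore.Tensor(np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4))
y = pool(x)                                    # mean over each 2x2 window -> (1, 1, 2, 2)
```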
5 changes: 0 additions & 5 deletions mindnlp/core/ops/array.py
@@ -91,10 +91,6 @@ def index_add(input, dim, index, source, *, alpha=1):
        return mindspore.mint.index_add(input, dim, index, source, alpha=alpha)
    return ops.index_add(input, index, source, dim)

def inplace_index_add(input, dim, index, source):
    _inplace = _get_cache_prim(ops.InplaceIndexAdd)(dim)
    return _inplace(input, index, source)

# index_copy


@@ -748,7 +744,6 @@ def strided_slice_update(input, begin, end, strides, update, begin_mask=0, end_m
    'hstack',
    'index_fill',
    'index_add',
    'inplace_index_add',
    # index_copy
    # index_reduce
    'index_select',
16 changes: 15 additions & 1 deletion mindnlp/core/ops/inplace.py
@@ -1,3 +1,6 @@
import mindspore
from mindspore import ops
from mindspore.ops._primitive_cache import _get_cache_prim
from mindspore.common.generator import default_generator
from mindspore.ops.auto_generate.gen_ops_prim import inplace_normal_op

@@ -84,6 +87,15 @@ def inplace_scatter(input, dim, index, src):
        return execute('inplace_scatter_value', input, dim, index, src)
    return execute('inplace_scatter', input, dim, index, src)

def inplace_index_copy(input, dim, index, tensor):
    # Emulate index_copy_ with index_add_: first subtract the rows currently
    # at `index` (cancelling them to zero), then add the replacement rows.
    selected = input.index_select(dim, index)
    input.index_add_(dim, index, -selected)
    input.index_add_(dim, index, tensor)
    return input

def inplace_index_add(input, dim, index, source):
    _inplace = _get_cache_prim(ops.InplaceIndexAdd)(dim)
    return _inplace(input, index, source)

__all__ = [
    'inplace_copy',
@@ -92,5 +104,7 @@ def inplace_scatter(input, dim, index, src):
    'inplace_fill',
    'inplace_uniform',
    'inplace_add',
    'inplace_scatter'
    'inplace_scatter',
    'inplace_index_copy',
    'inplace_index_add'
]
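A numeric sketch of the `inplace_index_copy` emulation (assumptions: `mindnlp.core.ops` re-exports the helpers listed in `__all__`, and `index` holds no duplicates; with repeated indices the cancel-then-add trick would diverge from a true copy):

```python
import numpy as np
import mindspore
from mindnlp.core import ops  # assumed re-export of the inplace helpers

x = mindspore.Tensor(np.zeros((3, 2), dtype=np.float32))
idx = mindspore.Tensor(np.array([0, 2], dtype=np.int32))
src = mindspore.Tensor(np.ones((2, 2), dtype=np.float32))

# Cancel the rows currently at idx, then add the replacement rows.
ops.inplace_index_copy(x, 0, idx, src)
# x is now [[1, 1], [0, 0], [1, 1]]
```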
12 changes: 12 additions & 0 deletions mindnlp/core/ops/other.py
@@ -95,6 +95,13 @@ def broadcast_shapes(*shapes):


# bucketize
def bucketize(input, boundaries, *, out_int32=False, right=False, out=None):
    if isinstance(boundaries, mindspore.Tensor):
        boundaries = boundaries.tolist()
    out = ops.bucketize(input, boundaries, right=right)
    if not out_int32:
        out = out.to(mindspore.int64)
    return out

# cartesian_prod

@@ -733,6 +740,10 @@ def repeat_interleave(input, repeats, dim=None):
        repeats = repeats[0]
        if repeats == 0:
            return Tensor_(input.dtype, (0,))
        if input.dtype == mindspore.bool_:
            input = input.to(mindspore.int32)
            out = ops.repeat_elements(input, repeats, dim)
            return out.to(mindspore.bool_)
        return ops.repeat_elements(input, repeats, dim)
    size = input.shape[dim]
    if len(repeats) != size:
@@ -992,6 +1003,7 @@ def unfold(input, dimension, size, step):
    "broadcast_shapes",
    "broadcast_tensors",
    "broadcast_to",
    "bucketize",
    "cdist",
    "clone",
    "contains",
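A brief sketch of the new `bucketize` wrapper (import path assumed): tensor boundaries are converted to a Python list for `mindspore.ops.bucketize`, and the result is widened to int64 unless `out_int32=True`, mirroring the `torch.bucketize` contract:

```python
import numpy as np
import mindspore
from mindnlp.core import ops  # assumed re-export of this module

v = mindspore.Tensor(np.array([1.0, 2.5, 7.0], dtype=np.float32))
boundaries = mindspore.Tensor(np.array([2.0, 5.0], dtype=np.float32))

idx = ops.bucketize(v, boundaries)                     # torch-style indices [0, 1, 2], int64
idx32 = ops.bucketize(v, boundaries, out_int32=True)   # keep the backend's int32 result
```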
10 changes: 5 additions & 5 deletions mindnlp/utils/safetensors_patch.py
@@ -88,12 +88,12 @@ def get(self, *args, **kwargs):
        buffer = bytearray(nbytes)
        self.bufferfile.seek(self.start_offset)
        self.bufferfile.readinto(buffer)
        tensor = np.frombuffer(buffer, dtype=self.dtype).reshape(self.shape)
        tensor = tensor.reshape(self.shape)
        array = np.frombuffer(buffer, dtype=self.dtype).reshape(self.shape)
        array = array.reshape(self.shape)
        if not SUPPORT_BF16 and self.info["dtype"] == 'BF16':
            tensor = tensor.astype(np.float16)
        tensor = Tensor.from_numpy(tensor)

            array = array.astype(np.float16)
        tensor = Tensor.from_numpy(array)
        tensor._ptr = array.ctypes.data
        return tensor

    @property