Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion mindnlp/core/_tensor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import math
import numpy as np
from functools import partial
import warnings
import mindspore
from mindspore import Tensor
from mindspore.common.tensor import _TensorMeta
Expand Down Expand Up @@ -379,6 +379,8 @@ def unfold(self, dimension, size, step):
StubTensor.unfold = unfold

def new(self, *shape):
if not isinstance(shape[0], int):
return tensor(shape[0], dtype=self.dtype)
return ops.empty(*shape, dtype=self.dtype)

Tensor.new = new
Expand Down Expand Up @@ -540,6 +542,21 @@ def new_tensor(self, data, *, dtype=None, device=None, requires_grad=False, layo
Tensor.triu_ = ops.inplace_triu
StubTensor.triu_ = ops.inplace_triu

@property
def real(self):
return ops.real(self)

Tensor.real = real
StubTensor.real = real

def bfloat16(self):
if ON_A1:
warnings.warn('910A do not support bfloat16, use float16 instead.')
return self.to(_dtype.float16)
return self.to(_dtype.bfloat16)

Tensor.bfloat16 = bfloat16
StubTensor.bfloat16 = bfloat16

def _rebuild_from_type_v2(func, new_type, args, state):
ret = func(*args)
Expand Down
1 change: 0 additions & 1 deletion mindnlp/core/nn/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,6 @@ def avg_pool2d(input, kernel_size, stride=None, padding=0, ceil_mode=False, coun
Returns:
- numpy array: The result of the average pooling operation.
"""
print(kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)
if use_pyboost():
return mint.nn.functional.avg_pool2d(input, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)

Expand Down
Loading