12 changes: 12 additions & 0 deletions mindnlp/core/_tensor.py
@@ -800,6 +800,18 @@ def tobytes(self):
Tensor.erfinv_ = ops.inplace_erfinv
StubTensor.erfinv_ = ops.inplace_erfinv

def is_pinned(self):
return False

Tensor.is_pinned = is_pinned
StubTensor.is_pinned = is_pinned

def record_stream(self, stream):
pass

Tensor.record_stream = record_stream
StubTensor.record_stream = record_stream

def _rebuild_from_type_v2(func, new_type, args, state):
ret = func(*args)
return ret
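# Hedged usage sketch (not from the PR): the two stubs above exist purely for torch API
# compatibility, so pinned-memory checks and stream bookkeeping become harmless no-ops.
# Assumes mindnlp.core is used as a drop-in torch-style namespace and a device stream exists.
from mindnlp import core
x = core.ones(2, 3)
assert x.is_pinned() is False        # always False: pinned host memory is not tracked here
s = core.cuda.Stream()               # assumption: an accelerator runtime is initialized
x.record_stream(s)                   # no-op, kept so torch-style code does not crash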
2 changes: 1 addition & 1 deletion mindnlp/core/configs.py
@@ -4,7 +4,7 @@

SOC = MSContext.get_instance().get_ascend_soc_version()
DEVICE_TARGET = mindspore.get_context('device_target')
SUPPORT_BF16 = DEVICE_TARGET == 'Ascend' and SOC not in ['ascend910', 'ascend310b1', 'ascend310b4']
ON_A1 = not SUPPORT_BF16
ON_ORANGE_PI = '310b' in SOC
USE_PYBOOST = DEVICE_TARGET == 'Ascend'
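# Illustrative restatement of the new SUPPORT_BF16 rule (the SoC strings below are example
# values of get_ascend_soc_version(); the helper itself is hypothetical, not part of the PR).
def supports_bf16(device_target: str, soc: str) -> bool:
    return device_target == 'Ascend' and soc not in ['ascend910', 'ascend310b1', 'ascend310b4']

assert supports_bf16('Ascend', 'ascend910b') is True    # e.g. 910B-class SoCs keep bf16
assert supports_bf16('Ascend', 'ascend910') is False    # A1-series: ON_A1 becomes True
assert supports_bf16('GPU', 'ascend910b') is False      # non-Ascend targets never report bf16 here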
22 changes: 18 additions & 4 deletions mindnlp/core/cuda/__init__.py
@@ -1,10 +1,12 @@
from typing import Any, Optional

import mindspore
from mindspore import get_rng_state, set_rng_state, manual_seed
from mindspore.hal import *
from mindspore.runtime import memory_reserved as ms_memory_reserved, \
memory_allocated as ms_memory_allocated, StreamCtx as StreamContext, Stream, empty_cache, \
reset_peak_memory_stats, reset_max_memory_allocated, max_memory_allocated, synchronize, \
current_stream
from mindspore.device_context.gpu import device_count

from mindnlp import core

@@ -56,4 +58,16 @@ def memory_reserved(device=None):
return ms_memory_reserved()

def memory_allocated(device=None):
return ms_memory_allocated()

def stream(stream: Optional["torch.cuda.Stream"]) -> StreamContext:
r"""Wrap around the Context-manager StreamContext that selects a given stream.

Arguments:
stream (Stream): selected stream. This manager is a no-op if it's
``None``.
.. note::
In eager mode stream is of type Stream class while in JIT it is
an object of the custom class ``torch.classes.cuda.Stream``.
"""
return StreamContext(stream)
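# Minimal usage sketch for the new stream() helper (assumes an Ascend/GPU runtime is
# initialized; the workload line is illustrative).
from mindnlp import core

s = core.cuda.Stream()
with core.cuda.stream(s):        # returns StreamContext(s); passing None makes it a no-op
    y = core.ones(4) * 2         # work issued while s is the selected stream
core.cuda.synchronize()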
215 changes: 129 additions & 86 deletions mindnlp/core/nn/functional.py
@@ -119,40 +119,7 @@ def avg_pool1d(input, kernel_size, stride, padding=0, ceil_mode=False, count_inc
if use_pyboost():
return mint.nn.functional.avg_pool1d(input, kernel_size, stride, padding, ceil_mode, count_include_pad)

return ops.avg_pool1d(input, kernel_size, stride, padding, ceil_mode, count_include_pad)
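# Quick shape check for the simplified fallback above (illustrative; F stands for
# mindnlp.core.nn.functional).
from mindnlp.core.nn import functional as F
from mindnlp import core

x = core.randn(1, 3, 16)
y = F.avg_pool1d(x, kernel_size=4, stride=2)   # (16 - 4) // 2 + 1 = 7 output positions
assert y.shape == (1, 3, 7)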

def avg_pool2d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None):
"""
@@ -172,6 +139,8 @@ def avg_pool2d(input, kernel_size, stride=None, padding=0, ceil_mode=False, coun
if use_pyboost():
return mint.nn.functional.avg_pool2d(input, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)

if divisor_override is None:
divisor_override = 0
return ops.avg_pool2d(input, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)
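# Why divisor_override maps to 0: ops.avg_pool2d expects an int here, and 0 presumably
# means "no override, divide by the kernel size" (an assumption based on the mapping above).
from mindnlp.core.nn import functional as F
from mindnlp import core

x = core.ones(1, 1, 4, 4)
y = F.avg_pool2d(x, kernel_size=2)   # averaging ones stays all ones
assert y.shape == (1, 1, 2, 2)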

has_avg_pool3d = hasattr(mint.nn.functional, 'avg_pool3d')
@@ -318,6 +287,16 @@ def custom_circular_pad(x, pad):
return x

def pad(input, pad, mode='constant', value=None):
if input.device.type != 'npu':
if mode == 'reflect' and input.ndim > 4:
paddings = [[0, 0]]
for i in range(0, len(pad), 2):
paddings.append([pad[i], pad[i+1]])
old_shape = input.shape
shape = (-1, *old_shape[-3:])
out = ops.MirrorPad()(input.reshape(shape), mindspore.Tensor(paddings))
return out.reshape(*old_shape[:-3], *out.shape[-3:])
return ops.pad(input, pad, mode, value)
if sum(pad) == 0:
return input
if isinstance(pad, tuple):
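# The same collapse-pad-restore idea as the new >4-D 'reflect' branch above, shown in plain
# NumPy to make the reshape trick concrete (sketch only; shapes are made-up examples).
import numpy as np

x = np.zeros((2, 3, 4, 5, 6))                              # 5-D input, reflect pad unsupported off-NPU
flat = x.reshape(-1, *x.shape[-3:])                        # collapse leading axes -> (6, 4, 5, 6)
padded = np.pad(flat, [(0, 0), (1, 1), (2, 2), (3, 3)], mode='reflect')
out = padded.reshape(*x.shape[:-3], *padded.shape[-3:])    # restore the leading axes
assert out.shape == (2, 3, 6, 9, 12)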
@@ -852,36 +831,34 @@ def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
def conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
if use_pyboost():
return mint.nn.functional.conv3d(input, weight, bias, stride, padding, dilation, groups)
"""
pad_mode = 'pad'
pad = padding
if isinstance(padding, tuple):
pad = (padding[0], padding[0], padding[1], padding[1])
elif isinstance(padding, int):
pad = (padding,) * 6
if not isinstance(padding, (int, tuple)):
pad_mode = padding
pad = (0,) * 6

self.conv3d = mops.Conv3D(out_channel=self.out_channels,
kernel_size=self.kernel_size,
mode=1,
pad_mode=pad_mode,
pad=pad,
stride=self.stride,
dilation=self.dilation,
group=self.groups)
if self.padding_mode != 'zeros':
input = ops.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode)
output = self.conv3d(input, self.weight)

if self.bias is not None:
output = mops.bias_add(output, self.bias)

pad_mode = 'pad'
pad = padding
if isinstance(padding, (tuple, list)):
pad = (padding[0], padding[0], padding[1], padding[1], padding[2], padding[2])
elif isinstance(padding, int):
pad = (padding,) * 6
if not isinstance(padding, (int, tuple, list)):
pad_mode = padding
pad = (0,) * 6

out_channels = weight.shape[0]
kernel_size = weight.shape[2:]
conv3d_op = ops.Conv3D(out_channels,
kernel_size,
mode=1,
pad_mode=pad_mode,
pad=pad,
stride=stride,
dilation=dilation,
group=groups)
output = conv3d_op(input, weight)
if bias is not None:
output = ops.bias_add(output, bias)
return output


"""
raise ValueError("Requires mindspore >= 2.3.0 by default, or set into pyboost mode by calling torch.config.set_byboost(True).")
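# Shape sanity check for the new non-pyboost conv3d path (illustrative values only;
# F stands for mindnlp.core.nn.functional).
from mindnlp.core.nn import functional as F
from mindnlp import core

x = core.randn(1, 4, 8, 8, 8)        # (N, C_in, D, H, W)
w = core.randn(6, 4, 3, 3, 3)        # (C_out, C_in // groups, kD, kH, kW)
y = F.conv3d(x, w, padding=1)        # padding=1 expands to pad=(1,)*6, preserving D/H/W
assert y.shape == (1, 6, 8, 8, 8)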

def conv_transpose1d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
x_2d = input.unsqueeze(2) # (batch, in_channels, 1, L_in)
@@ -903,38 +880,104 @@ def conv_transpose1d(input, weight, bias=None, stride=1, padding=0, output_paddi
# 4. Remove the height dimension to restore the 1-D output
return output_2d.squeeze(2)

def _deconv_output_length(pad_mode, filter_size, stride_size, dilation_size, padding):
"""Calculate the width and height of output."""
length = 0
filter_size = filter_size + (filter_size - 1) * (dilation_size - 1)
if pad_mode == 'valid':
if filter_size - stride_size > 0:
length = filter_size - stride_size
elif pad_mode == 'pad':
length = - padding + filter_size - stride_size

return length
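# Worked example for _deconv_output_length (illustrative numbers): with kernel 3, stride 2,
# dilation 1 and total padding 2 (1 per side), the extra length is -2 + 3 - 2 = -1, so
# conv_transpose2d below produces H_out = H_in * 2 - 1, matching the usual
# (H_in - 1) * stride - 2 * pad + kernel formula.
assert _deconv_output_length('pad', 3, 2, 1, 2) == -1
assert _deconv_output_length('valid', 3, 2, 1, 0) == 1    # kernel minus stride when positive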

def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
if use_pyboost():
return mint.nn.functional.conv_transpose2d(input, weight, bias, stride, padding, output_padding, groups, dilation)

pad_mode = 'pad'
pad = padding
if isinstance(padding, tuple):
pad = (0, 0, padding[0], padding[0])
elif isinstance(padding, int):
pad = (0, 0) + (padding,) * 2
if not isinstance(padding, (int, tuple)):
pad_mode = padding
pad = (0,) * 4

in_channel, out_channels = weight.shape[0], weight.shape[1] * groups
kernel_size = weight.shape[2:]

conv2d_transpose_op = ops.Conv2DTranspose(out_channel=out_channels,
kernel_size=kernel_size,
mode=1,
pad_mode=pad_mode,
pad=pad,
stride=stride,
dilation=dilation,
group=groups)
n, _, h, w = input.shape
h_add = _deconv_output_length(pad_mode, kernel_size[0], stride[0], dilation[0], pad[0] + pad[1])
w_add = _deconv_output_length(pad_mode, kernel_size[1], stride[1], dilation[1], pad[2] + pad[3])

out = conv2d_transpose_op(input, weight,
(n, out_channels, h * stride[0] + h_add, w * stride[1] + w_add))
if bias is not None:
out = ops.bias_add(out, bias)
return out
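# Shape sketch for the fallback path above; stride and dilation are indexed, so tuple
# arguments are assumed (illustrative values only).
from mindnlp.core.nn import functional as F
from mindnlp import core

x = core.randn(1, 8, 5, 5)
w = core.randn(8, 4, 3, 3)                              # (C_in, C_out // groups, kH, kW)
y = F.conv_transpose2d(x, w, stride=(2, 2), dilation=(1, 1))
assert y.shape == (1, 4, 11, 11)                        # (5 - 1) * 2 + 3 = 11 per spatial dim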

def conv_transpose3d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
if input.device.type == 'npu':
in_channel, out_channel = weight.shape[0], weight.shape[1]
kernel_size = weight.shape[2:]
conv_transpose3d_op = ops.Conv3DTranspose(
in_channel,
out_channel,
kernel_size,
mode=1,
pad_mode='valid',
pad=padding,
stride=stride,
dilation=dilation,
group=1,
output_padding=output_padding,
data_format="NCDHW"
)
if groups > 1:
outputs = ()
for i in range(groups):
output = conv_transpose3d_op(input.half(), weight.half())
if bias is not None:
output = output + bias
outputs = outputs + (output,)
out = ops.concat(outputs, 1)
else:
out = conv_transpose3d_op(input, weight)
if bias is not None:
out = out + bias
return out
else:
in_channel, out_channel = weight.shape[0], weight.shape[1] * groups
kernel_size = weight.shape[2:]
conv_transpose3d_op = ops.Conv3DTranspose(
in_channel,
out_channel,
kernel_size,
mode=1,
pad_mode='valid',
pad=padding,
stride=stride,
dilation=dilation,
group=groups,
output_padding=output_padding,
data_format="NCDHW"
)

out = conv_transpose3d_op(input, weight)
if bias is not None:
out = out + bias
return out
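# Sketch for the non-NPU branch: Conv3DTranspose takes group=groups directly, so a single
# op call covers grouped deconvolution (illustrative values; assumes a CPU/GPU context).
from mindnlp.core.nn import functional as F
from mindnlp import core

x = core.randn(1, 8, 4, 4, 4)
w = core.randn(8, 2, 3, 3, 3)               # (C_in, C_out // groups, kD, kH, kW); groups=2 -> C_out=4
y = F.conv_transpose3d(x, w, groups=2)
assert y.shape == (1, 4, 6, 6, 6)           # 4 + 3 - 1 = 6 per spatial dim with stride 1, no padding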


def max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False):
4 changes: 3 additions & 1 deletion mindnlp/core/ops/creation.py
@@ -212,7 +212,9 @@ def empty(*size, dtype=None, device=None, requires_grad=False, pin_memory=False,
# empty_like
has_empty_like = hasattr(mindspore.mint, 'empty_like')
def empty_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=None):
if use_pyboost():
return mindspore.mint.empty_like(input, dtype=dtype, device=device)
return mindspore.Tensor(np.empty(input.shape, dtype=dtype2np[input.dtype]))
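# Fallback sketch: without pyboost, empty_like allocates an uninitialized numpy buffer of the
# same shape and maps the dtype through the module's dtype2np table. Values are garbage by
# design; only shape and dtype are meaningful (assumes empty_like is re-exported at package level).
from mindnlp import core

x = core.ones(2, 3)
y = core.empty_like(x)
assert y.shape == (2, 3)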

# empty_strided

9 changes: 8 additions & 1 deletion mindnlp/core/ops/random.py
@@ -9,6 +9,7 @@
from .pointwise import div, log
from .._bind import get_default_dtype
from ._inner import call_ms_func
from .._C import default_generator

# bernoulli
has_bernoulli = hasattr(mindspore.mint, 'bernoulli')
@@ -117,6 +118,8 @@ def randint_like(*args, **kwargs):
# randn
has_randn = hasattr(mindspore.mint, 'randn')
def randn(*size, generator=None, dtype=None, **kwargs):
if isinstance(size[0], tuple):
size = size[0]
size = kwargs.pop('size', size)
new_size = ()
for s in size:
@@ -128,7 +131,11 @@ def randn(*size, generator=None, dtype=None, **kwargs):
if use_pyboost() and has_randn:
return mindspore.mint.randn(*new_size, generator=generator, dtype=dtype)
# return ops.randn(*new_size, dtype=dtype)
if not generator:
generator = default_generator
seed, _ = generator._step(12)
rng = np.random.default_rng(seed.item())
return mindspore.Tensor(rng.standard_normal(new_size), dtype=dtype)
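# The new size normalization accepts both call styles, and the generator-seeded numpy
# fallback keeps randn usable off the pyboost path (shapes only; values are random).
from mindnlp import core

a = core.randn(2, 3)        # varargs, as before
b = core.randn((2, 3))      # a single tuple, now unpacked by the added isinstance check
assert a.shape == b.shape == (2, 3)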

# randn_like
has_randn_like = hasattr(mindspore.mint, 'randn_like')
Expand Down
2 changes: 2 additions & 0 deletions mindnlp/core/random.py
@@ -3,6 +3,7 @@
import warnings
from typing import Generator

import mindspore
from mindnlp import core
# from mindspore import default_generator, set_seed

@@ -53,6 +54,7 @@ def manual_seed(seed):
is raised. Negative inputs are remapped to positive values with the formula
`0xffff_ffff_ffff_ffff + seed`.
"""
mindspore.set_seed(seed)
seed = int(seed)
# set_seed(seed)
return default_generator.manual_seed(seed)