From 65a4fdb857c7f93f300dedbd32917d2402df65e8 Mon Sep 17 00:00:00 2001
From: lvyufeng
Date: Fri, 1 Aug 2025 10:19:54 +0000
Subject: [PATCH] fix diffusers autoencoder on GPU

---
 mindnlp/core/_tensor.py       |  12 ++
 mindnlp/core/configs.py       |   2 +-
 mindnlp/core/cuda/__init__.py |  22 +++-
 mindnlp/core/nn/functional.py | 215 ++++++++++++++++++++--------------
 mindnlp/core/ops/creation.py  |   4 +-
 mindnlp/core/ops/random.py    |   9 +-
 mindnlp/core/random.py        |   2 +
 7 files changed, 173 insertions(+), 93 deletions(-)

diff --git a/mindnlp/core/_tensor.py b/mindnlp/core/_tensor.py
index 1c6dbddec..5933296cb 100644
--- a/mindnlp/core/_tensor.py
+++ b/mindnlp/core/_tensor.py
@@ -800,6 +800,18 @@ def tobytes(self):
 Tensor.erfinv_ = ops.inplace_erfinv
 StubTensor.erfinv_ = ops.inplace_erfinv
 
+def is_pinned(self):
+    return False
+
+Tensor.is_pinned = is_pinned
+StubTensor.is_pinned = is_pinned
+
+def record_stream(self, stream):
+    pass
+
+Tensor.record_stream = record_stream
+StubTensor.record_stream = record_stream
+
 def _rebuild_from_type_v2(func, new_type, args, state):
     ret = func(*args)
     return ret
\ No newline at end of file
diff --git a/mindnlp/core/configs.py b/mindnlp/core/configs.py
index 6a65f786a..6543b55c1 100644
--- a/mindnlp/core/configs.py
+++ b/mindnlp/core/configs.py
@@ -4,7 +4,7 @@
 
 SOC = MSContext.get_instance().get_ascend_soc_version()
 DEVICE_TARGET = mindspore.get_context('device_target')
-SUPPORT_BF16 = SOC in ["ascend910b", "ascend910_93"]
+SUPPORT_BF16 = DEVICE_TARGET == 'Ascend' and SOC not in ['ascend910', 'ascend310b1', 'ascend310b4']
 ON_A1 = not SUPPORT_BF16
 ON_ORANGE_PI = '310b' in SOC
 USE_PYBOOST = DEVICE_TARGET == 'Ascend'
diff --git a/mindnlp/core/cuda/__init__.py b/mindnlp/core/cuda/__init__.py
index a1da3c032..823d1bce1 100644
--- a/mindnlp/core/cuda/__init__.py
+++ b/mindnlp/core/cuda/__init__.py
@@ -1,10 +1,12 @@
-from typing import Any
+from typing import Any, Optional
 import mindspore
 from mindspore import get_rng_state, set_rng_state, manual_seed
-from mindspore.hal import *
 from mindspore.runtime import memory_reserved as ms_memory_reserved, \
-    memory_allocated as ms_memory_allocated
+    memory_allocated as ms_memory_allocated, StreamCtx as StreamContext, Stream, empty_cache, \
+    reset_peak_memory_stats, reset_max_memory_allocated, max_memory_allocated, synchronize, \
+    current_stream
+from mindspore.device_context.gpu import device_count
 
 from mindnlp import core
 
@@ -56,4 +58,16 @@ def memory_reserved(device=None):
     return ms_memory_reserved()
 
 def memory_allocated(device=None):
-    return ms_memory_allocated()
\ No newline at end of file
+    return ms_memory_allocated()
+
+def stream(stream: Optional["torch.cuda.Stream"]) -> StreamContext:
+    r"""Wrap around the Context-manager StreamContext that selects a given stream.
+
+    Arguments:
+        stream (Stream): selected stream. This manager is a no-op if it's
+            ``None``.
+    .. note::
+        In eager mode stream is of type Stream class while in JIT it is
+        an object of the custom class ``torch.classes.cuda.Stream``.
+    """
+    return StreamContext(stream)
\ No newline at end of file
diff --git a/mindnlp/core/nn/functional.py b/mindnlp/core/nn/functional.py
index 22b41035f..4409a7fa1 100644
--- a/mindnlp/core/nn/functional.py
+++ b/mindnlp/core/nn/functional.py
@@ -119,40 +119,7 @@ def avg_pool1d(input, kernel_size, stride, padding=0, ceil_mode=False, count_inc
     if use_pyboost():
         return mint.nn.functional.avg_pool1d(input, kernel_size, stride, padding, ceil_mode, count_include_pad)
 
-    N, C, L = input_array.shape
-
-    # Add padding to the input array
-    if padding > 0:
-        input_array = ops.pad(input_array, ((0, 0), (0, 0), (padding, padding)), mode='constant', value=(0, 0))
-
-    # Calculate the output length
-    if ceil_mode:
-        output_length = int(np.ceil((L + 2 * padding - pool_size) / stride).astype(int) + 1)
-    else:
-        output_length = int(np.floor((L + 2 * padding - pool_size) / stride).astype(int) + 1)
-
-    # Initialize the output array
-    output_array = ops.zeros((N, C, output_length))
-
-    # Generate the starting indices of the pooling windows
-    indices = ops.arange(output_length) * stride
-    indices = indices[:, None] + ops.arange(pool_size)
-
-    # Ensure indices are within bounds
-    indices = ops.minimum(indices, input_array.shape[2] - 1)
-
-    # Use advanced indexing to extract the pooling windows
-    windows = input_array[:, :, indices]
-
-    # Calculate the mean along the pooling window dimension
-    if count_include_pad:
-        output_array = ops.mean(windows, axis=-1)
-    else:
-        valid_counts = ops.sum(windows != 0, dim=-1)
-        valid_counts = ops.maximum(valid_counts, 1)  # Avoid division by zero
-        output_array = ops.sum(windows, dim=-1) / valid_counts
-
-    return output_array
+    return ops.avg_pool1d(input, kernel_size, stride, padding, ceil_mode, count_include_pad)
 
 def avg_pool2d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None):
     """
@@ -172,6 +139,8 @@ def avg_pool2d(input, kernel_size, stride=None, padding=0, ceil_mode=False, coun
     if use_pyboost():
         return mint.nn.functional.avg_pool2d(input, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)
 
+    if divisor_override is None:
+        divisor_override = 0
     return ops.avg_pool2d(input, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)
 
 has_avg_pool3d = hasattr(mint.nn.functional, 'avg_pool3d')
@@ -318,6 +287,16 @@ def custom_circular_pad(x, pad):
     return x
 
 def pad(input, pad, mode='constant', value=None):
+    if input.device.type != 'npu':
+        if mode == 'reflect' and input.ndim > 4:
+            paddings = [[0, 0]]
+            for i in range(0, len(pad), 2):
+                paddings.append([pad[i], pad[i+1]])
+            old_shape = input.shape
+            shape = (-1, *old_shape[-3:])
+            out = ops.MirrorPad()(input.reshape(shape), mindspore.Tensor(paddings))
+            return out.reshape(*old_shape[:-3], *out.shape[-3:])
+        return ops.pad(input, pad, mode, value)
     if sum(pad) == 0:
         return input
     if isinstance(pad, tuple):
@@ -852,36 +831,34 @@ def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
 
 def conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
     if use_pyboost():
         return mint.nn.functional.conv3d(input, weight, bias, stride, padding, dilation, groups)
-    """
-    pad_mode = 'pad'
-    pad = padding
-    if isinstance(padding, tuple):
-        pad = (padding[0], padding[0], padding[1], padding[1])
-    elif isinstance(padding, int):
-        pad = (padding,) * 6
-    if not isinstance(padding, (int, tuple)):
-        pad_mode = padding
-        pad = (0,) * 6
-
-    self.conv3d = mops.Conv3D(out_channel=self.out_channels,
-                              kernel_size=self.kernel_size,
-                              mode=1,
-                              pad_mode=pad_mode,
-                              pad=pad,
-                              stride=self.stride,
-                              dilation=self.dilation,
-                              group=self.groups)
-    if self.padding_mode != 'zeros':
-        input = ops.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode)
-    output = self.conv3d(input, self.weight)
-
-
-    if self.bias is not None:
-        output = mops.bias_add(output, self.bias)
+
+    pad_mode = 'pad'
+    pad = padding
+    if isinstance(padding, (tuple, list)):
+        pad = (padding[0], padding[0], padding[1], padding[1], padding[2], padding[2])
+    elif isinstance(padding, int):
+        pad = (padding,) * 6
+    if not isinstance(padding, (int, tuple, list)):
+        pad_mode = padding
+        pad = (0,) * 6
+
+    out_channels = weight.shape[0]
+    kernel_size = weight.shape[2:]
+    conv3d_op = ops.Conv3D(out_channels,
+                           kernel_size,
+                           mode=1,
+                           pad_mode=pad_mode,
+                           pad=pad,
+                           stride=stride,
+                           dilation=dilation,
+                           group=groups)
+    output = conv3d_op(input, weight)
+
+    if bias is not None:
+        output = ops.bias_add(output, bias)
+
+    return output
-    """
-    raise ValueError("Requires mindspore >= 2.3.0 by default, or set into pyboost mode by calling torch.config.set_byboost(True).")
 
 def conv_transpose1d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
     x_2d = input.unsqueeze(2)  # (batch, in_channels, 1, L_in)
@@ -903,38 +880,104 @@ def conv_transpose1d(input, weight, bias=None, stride=1, padding=0, output_paddi
     # 4. Remove the height dimension to restore the 1-D output
     return output_2d.squeeze(2)
 
+def _deconv_output_length(pad_mode, filter_size, stride_size, dilation_size, padding):
+    """Calculate the width and height of output."""
+    length = 0
+    filter_size = filter_size + (filter_size - 1) * (dilation_size - 1)
+    if pad_mode == 'valid':
+        if filter_size - stride_size > 0:
+            length = filter_size - stride_size
+    elif pad_mode == 'pad':
+        length = - padding + filter_size - stride_size
+
+    return length
+
 def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
-    return mint.nn.functional.conv_transpose2d(input, weight, bias, stride, padding, output_padding, groups, dilation)
+    if use_pyboost():
+        return mint.nn.functional.conv_transpose2d(input, weight, bias, stride, padding, output_padding, groups, dilation)
 
-def conv_transpose3d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
-    in_channel, out_channel = weight.shape[0], weight.shape[1]
+    pad_mode = 'pad'
+    pad = padding
+    if isinstance(padding, tuple):
+        pad = (0, 0, padding[0], padding[0])
+    elif isinstance(padding, int):
+        pad = (0, 0) + (padding,) * 2
+    if not isinstance(padding, (int, tuple)):
+        pad_mode = padding
+        pad = (0,) * 4
+
+    in_channel, out_channels = weight.shape[0], weight.shape[1] * groups
     kernel_size = weight.shape[2:]
-    conv_transpose3d_op = ops.Conv3DTranspose(
-        in_channel,
-        out_channel,
-        kernel_size,
-        mode=1,
-        pad_mode='valid',
-        pad=padding,
-        stride=stride,
-        dilation=dilation,
-        group=1,
-        output_padding=output_padding,
-        data_format="NCDHW"
-    )
-    if groups > 1:
-        outputs = ()
-        for i in range(groups):
-            output = conv_transpose3d_op(input.half(), weight.half())
+
+    conv2d_transpose_op = ops.Conv2DTranspose(out_channel=out_channels,
+                                              kernel_size=kernel_size,
+                                              mode=1,
+                                              pad_mode=pad_mode,
+                                              pad=pad,
+                                              stride=stride,
+                                              dilation=dilation,
+                                              group=groups)
+    n, _, h, w = input.shape
+    h_add = _deconv_output_length(pad_mode, kernel_size[0], stride[0], dilation[0], pad[0] + pad[1])
+    w_add = _deconv_output_length(pad_mode, kernel_size[1], stride[1], dilation[1], pad[2] + pad[3])
+
+    out = conv2d_transpose_op(input, weight,
+                              (n, out_channels, h * stride[0] + h_add, w * stride[1] + w_add))
+    if bias is not None:
+        out = ops.bias_add(out, bias)
+    return out
+
+def conv_transpose3d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
+    if input.device.type == 'npu':
+        in_channel, out_channel = weight.shape[0], weight.shape[1]
+        kernel_size = weight.shape[2:]
+        conv_transpose3d_op = ops.Conv3DTranspose(
+            in_channel,
+            out_channel,
+            kernel_size,
+            mode=1,
+            pad_mode='valid',
+            pad=padding,
+            stride=stride,
+            dilation=dilation,
+            group=1,
+            output_padding=output_padding,
+            data_format="NCDHW"
+        )
+        if groups > 1:
+            outputs = ()
+            for i in range(groups):
+                output = conv_transpose3d_op(input.half(), weight.half())
+                if bias is not None:
+                    output = output + bias
+                outputs = outputs + (output,)
+            out = ops.concat(outputs, 1)
+        else:
+            out = conv_transpose3d_op(input, weight)
             if bias is not None:
-                output = output + bias
-            outputs = outputs + (output,)
-        out = ops.concat(outputs, 1)
+                out = out + bias
+        return out
     else:
+        in_channel, out_channel = weight.shape[0], weight.shape[1] * groups
+        kernel_size = weight.shape[2:]
+        conv_transpose3d_op = ops.Conv3DTranspose(
+            in_channel,
+            out_channel,
+            kernel_size,
+            mode=1,
+            pad_mode='valid',
+            pad=padding,
+            stride=stride,
+            dilation=dilation,
+            group=groups,
+            output_padding=output_padding,
+            data_format="NCDHW"
+        )
         out = conv_transpose3d_op(input, weight)
         if bias is not None:
             out = out + bias
-    return out
+        return out
 
 
 def max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False):
diff --git a/mindnlp/core/ops/creation.py b/mindnlp/core/ops/creation.py
index 08d2a4791..51e80049f 100644
--- a/mindnlp/core/ops/creation.py
+++ b/mindnlp/core/ops/creation.py
@@ -212,7 +212,9 @@ def empty(*size, dtype=None, device=None, requires_grad=False, pin_memory=False,
 # empty_like
 has_empty_like = hasattr(mindspore.mint, 'empty_like')
 def empty_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=None):
-    return mindspore.mint.empty_like(input, dtype=dtype, device=device)
+    if use_pyboost():
+        return mindspore.mint.empty_like(input, dtype=dtype, device=device)
+    return mindspore.Tensor(np.empty(input.shape, dtype=dtype2np[input.dtype]))
 
 
 # empty_strided
diff --git a/mindnlp/core/ops/random.py b/mindnlp/core/ops/random.py
index d8fad826e..2e8d993bb 100644
--- a/mindnlp/core/ops/random.py
+++ b/mindnlp/core/ops/random.py
@@ -9,6 +9,7 @@
 from .pointwise import div, log
 from .._bind import get_default_dtype
 from ._inner import call_ms_func
+from .._C import default_generator
 
 # bernoulli
 has_bernoulli = hasattr(mindspore.mint, 'bernoulli')
@@ -117,6 +118,8 @@ def randint_like(*args, **kwargs):
 # randn
 has_randn = hasattr(mindspore.mint, 'randn')
 def randn(*size, generator=None, dtype=None, **kwargs):
+    if isinstance(size[0], tuple):
+        size = size[0]
     size = kwargs.pop('size', size)
     new_size = ()
     for s in size:
@@ -128,7 +131,11 @@ def randn(*size, generator=None, dtype=None, **kwargs):
     if use_pyboost() and has_randn:
         return mindspore.mint.randn(*new_size, generator=generator, dtype=dtype)
     # return ops.randn(*new_size, dtype=dtype)
-    return mindspore.Tensor(np.random.randn(*new_size), dtype=dtype)
+    if not generator:
+        generator = default_generator
+    seed, _ = generator._step(12)
+    rng = np.random.default_rng(seed.item())
+    return mindspore.Tensor(rng.standard_normal(new_size), dtype=dtype)
 
 # randn_like
 has_randn_like = hasattr(mindspore.mint, 'randn_like')
diff --git a/mindnlp/core/random.py b/mindnlp/core/random.py
index bfa2c17bb..0612cb2a0 100644
--- a/mindnlp/core/random.py
+++ b/mindnlp/core/random.py
@@ -3,6 +3,7 @@
 import warnings
 from typing import Generator
 
+import mindspore
 from mindnlp import core
 
 # from mindspore import default_generator, set_seed
@@ -53,6 +54,7 @@ def manual_seed(seed):
     is raised. Negative inputs are remapped to positive values with the formula
     `0xffff_ffff_ffff_ffff + seed`.
     """
+    mindspore.set_seed(seed)
    seed = int(seed)
     # set_seed(seed)
     return default_generator.manual_seed(seed)
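
Usage sketch (not part of the commit): the snippet below is an illustrative, unverified example of two GPU code paths this patch touches, the ops.Conv3D fallback in conv3d and the ops.MirrorPad route for 'reflect' padding on 5-D inputs, which are the operations diffusers-style 3-D autoencoder blocks exercise. It assumes the patched functions are importable from mindnlp.core.nn.functional (as in the diff above) and that a non-NPU (GPU/CPU) backend is active; all shapes and values are made up for illustration.

    import numpy as np
    import mindspore
    from mindnlp.core.nn import functional as F  # patched functions from this commit

    # NCDHW input and a 3-D conv weight; shapes are illustrative only
    x = mindspore.Tensor(np.random.randn(1, 4, 8, 16, 16).astype(np.float32))
    w = mindspore.Tensor(np.random.randn(4, 4, 3, 3, 3).astype(np.float32))

    # With pyboost unavailable (e.g. on GPU), conv3d now builds an ops.Conv3D
    # primitive instead of raising ValueError.
    y = F.conv3d(x, w, padding=1)

    # 'reflect' padding on a tensor with more than 4 dims is routed through
    # ops.MirrorPad on non-NPU devices; symmetric pads as autoencoder blocks use.
    p = F.pad(x, (1, 1, 1, 1, 1, 1), mode='reflect')
    print(y.shape, p.shape)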