Merge pull request PaddlePaddle#43 from mthreads/fix_blas
[MTAI-484] fix(build): modify format for MUSA
caizhi-mt authored and mt-robot committed Aug 14, 2023
2 parents 48cc622 + 5a56ea0 commit 1ee924d
Showing 13 changed files with 227 additions and 393 deletions.
38 changes: 18 additions & 20 deletions paddle/phi/backends/gpu/gpu_resources.cc
@@ -155,9 +155,10 @@ void InitGpuProperties(Place place,
"version.";
}
#elif defined(PADDLE_WITH_MUSA)
- // TODO(@caizhi): enable dynload module
+ // TODO(@caizhi): mudnnGetVersion is not supported for MUSA now.
+ // Requests have been submitted to Mudnn.
+ // size_t mudnn_dso_ver = dynload::mudnnGetVersion();
- size_t mudnn_dso_ver = 0;
+ size_t mudnn_dso_ver = 1100;
LOG_FIRST_N(WARNING, 1) << "device: " << static_cast<int>(place.device)
<< ", muDNN Version: " << mudnn_dso_ver / 1000 << "."
<< (mudnn_dso_ver % 1000) / 100 << ".";
@@ -168,21 +169,20 @@ void InitGpuProperties(Place place,
auto compile_musa_version =
(MUSA_VERSION / 1000) * 10 + (MUSA_VERSION % 100) / 10;
#if defined(__linux__)
- // TODO(@caizhi): enable dynload module
- //PADDLE_ENFORCE_EQ(
- //    (local_musa_version / 10 < compile_musa_version / 10) &&
- //        (mudnn_dso_ver / 1000 < MUDNN_VERSION / 1000),
- //    false,
- //    phi::errors::InvalidArgument(
- //        "The installed Paddle is compiled with MUDA%d/muDNN%d,"
- //        "but MUSA/muDNN version in your machine is MUSA%d/muDNN%d. "
- //        "which will cause serious incompatible bug. "
- //        "Please recompile or reinstall Paddle with compatible MUSA/muDNN "
- //        "version.",
- //    compile_musa_version / 10,
- //    MUDNN_VERSION / 1000,
- //    local_musa_version / 10,
- //    mudnn_dso_ver / 1000));
+ PADDLE_ENFORCE_EQ(
+     (local_musa_version / 10 < compile_musa_version / 10) &&
+         (mudnn_dso_ver / 1000 < MUDNN_VERSION / 1000),
+     false,
+     phi::errors::InvalidArgument(
+         "The installed Paddle is compiled with MUDA%d/muDNN%d,"
+         "but MUSA/muDNN version in your machine is MUSA%d/muDNN%d. "
+         "which will cause serious incompatible bug. "
+         "Please recompile or reinstall Paddle with compatible MUSA/muDNN "
+         "version.",
+     compile_musa_version / 10,
+     MUDNN_VERSION / 1000,
+     local_musa_version / 10,
+     mudnn_dso_ver / 1000));
#endif
if (local_musa_version < compile_musa_version) {
LOG_FIRST_N(WARNING, 1)
@@ -335,9 +335,7 @@ void InitDnnHandle(dnnHandle_t* handle, gpuStream_t stream, Place place) {
}
PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenCreate(handle));
PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenSetStream(*handle, stream));
- #elif defined(PADDLE_WITH_MUSA)
- 
- #else
+ #elif defined(PADDLE_WITH_CUDA)
auto local_cudnn_version = phi::dynload::cudnnGetVersion() / 100;
auto compile_cudnn_version = CUDNN_VERSION / 100;
if (local_cudnn_version < static_cast<size_t>(compile_cudnn_version)) {
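Note: the hunks above hard-code mudnn_dso_ver = 1100 while dynload::mudnnGetVersion() is pending, and compare MUSA versions as two-digit major.minor codes. A standalone sketch of that arithmetic, assuming muDNN packs major*1000 + minor*100 and MUSA_VERSION packs major*1000 + minor*10 (a CUDA_VERSION-style encoding inferred from the expressions, not confirmed by the diff):

#include <cstdio>

int main() {
  // Stub value from the diff: under the assumed packing major*1000 +
  // minor*100, 1100 prints as muDNN 1.1.
  unsigned long mudnn_dso_ver = 1100;
  std::printf("muDNN %lu.%lu\n",
              mudnn_dso_ver / 1000, (mudnn_dso_ver % 1000) / 100);

  // Hypothetical MUSA_VERSION for MUSA 2.1, assuming the CUDA-style
  // packing major*1000 + minor*10: 2010 yields the two-digit code 21.
  long musa_version = 2010;
  long compile_musa_version =
      (musa_version / 1000) * 10 + (musa_version % 100) / 10;
  std::printf("compile_musa_version = %ld\n", compile_musa_version);
  return 0;
}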
38 changes: 18 additions & 20 deletions paddle/phi/kernels/funcs/math_function.cu
@@ -430,16 +430,15 @@ void ColwiseSum<phi::GPUContext, double>::operator()(

SetConstant<phi::GPUContext, double> set;
set(context, &one, static_cast<double>(1.0));
- // TODO(@caizhi): enable blas modules
- //phi::funcs::GetBlas<phi::GPUContext, double>(context).GEMV(
- //    true,
- //    static_cast<int>(in_dims[0]),
- //    static_cast<int>(in_dims[1]),
- //    1.0,
- //    input.data<double>(),
- //    one.data<double>(),
- //    0.0,
- //    vector->data<double>());
+ phi::funcs::GetBlas<phi::GPUContext, double>(context).GEMV(
+     true,
+     static_cast<int>(in_dims[0]),
+     static_cast<int>(in_dims[1]),
+     1.0,
+     input.data<double>(),
+     one.data<double>(),
+     0.0,
+     vector->data<double>());
}

template struct RowwiseSum<phi::GPUContext, float>;
@@ -469,16 +468,15 @@ void RowwiseSum<phi::GPUContext, double>::operator()(

SetConstant<phi::GPUContext, double> set;
set(context, &one, static_cast<double>(1.0));
- // TODO(@caizhi): enable blas modules
- //phi::funcs::GetBlas<phi::GPUContext, double>(context).GEMV(
- //    true,
- //    static_cast<int>(in_dims[1]),
- //    static_cast<int>(in_dims[0]),
- //    1.0,
- //    one.data<double>(),
- //    input.data<double>(),
- //    0.0,
- //    vector->data<double>());
+ phi::funcs::GetBlas<phi::GPUContext, double>(context).GEMV(
+     true,
+     static_cast<int>(in_dims[1]),
+     static_cast<int>(in_dims[0]),
+     1.0,
+     one.data<double>(),
+     input.data<double>(),
+     0.0,
+     vector->data<double>());
}

template struct RowwiseMean<phi::GPUContext, float>;
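Note: ColwiseSum and RowwiseSum are re-enabled above as a single GEMV against a vector of ones: with alpha = 1 and beta = 0, op(A) * 1 reduces one dimension of A. A plain-CPU sketch of the trick (not the Paddle GetBlas API; a row-major toy gemv written for this note):

#include <cstdio>
#include <vector>

// y = alpha * op(A) * x + beta * y. With x = ones, alpha = 1, beta = 0,
// op(A) = A^T gives column sums and op(A) = A gives row sums.
void gemv(bool trans, int m, int n, double alpha, const double* a,
          const double* x, double beta, double* y) {
  int rows = trans ? n : m, cols = trans ? m : n;
  for (int i = 0; i < rows; ++i) {
    double acc = 0.0;
    for (int j = 0; j < cols; ++j) {
      acc += (trans ? a[j * n + i] : a[i * n + j]) * x[j];
    }
    y[i] = alpha * acc + beta * y[i];
  }
}

int main() {
  const int m = 2, n = 3;
  std::vector<double> a = {1, 2, 3, 4, 5, 6};  // 2x3, row-major
  std::vector<double> ones_m(m, 1.0), ones_n(n, 1.0);
  std::vector<double> col_sum(n), row_sum(m);
  gemv(true, m, n, 1.0, a.data(), ones_m.data(), 0.0, col_sum.data());   // 5 7 9
  gemv(false, m, n, 1.0, a.data(), ones_n.data(), 0.0, row_sum.data());  // 6 15
  std::printf("col: %g %g %g  row: %g %g\n",
              col_sum[0], col_sum[1], col_sum[2], row_sum[0], row_sum[1]);
  return 0;
}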
4 changes: 2 additions & 2 deletions paddle/phi/kernels/funcs/softmax.cu
@@ -21,6 +21,7 @@ limitations under the License. */

namespace phi {
namespace funcs {
+ 
using ScopedTensorDescriptor = phi::backends::gpu::ScopedTensorDescriptor;
using DataLayout = phi::backends::gpu::DataLayout;
template <typename T>
@@ -117,8 +118,6 @@ void SoftmaxGradCUDNNFunctor<T, DeviceContext>::operator()(
MIOPEN_SOFTMAX_ACCURATE,
MIOPEN_SOFTMAX_MODE_INSTANCE));
- #elif defined(PADDLE_WITH_MUSA)
- // TODO
#else
cudnnTensorDescriptor_t cudnn_y_desc =
yDesc.descriptor<T>(layout, cudnn_tensor_dims);
cudnnTensorDescriptor_t cudnn_xgrad_desc =
@@ -154,6 +153,7 @@ template class SoftmaxGradCUDNNFunctor<phi::dtype::bfloat16, phi::GPUContext>;
template class SoftmaxCUDNNFunctor<double, phi::GPUContext>;
template class SoftmaxGradCUDNNFunctor<double, phi::GPUContext>;
#endif
+ 
template class SoftmaxFunctor<phi::GPUContext, phi::dtype::float16>;
template class SoftmaxFunctor<phi::GPUContext, phi::dtype::bfloat16>;
template class SoftmaxFunctor<phi::GPUContext, float>;
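Note: MIOPEN_SOFTMAX_ACCURATE above (like cuDNN's CUDNN_SOFTMAX_ACCURATE) is the max-shifted formulation exp(x - max(x)) / sum(exp(x - max(x))), algebraically identical to naive softmax but safe against overflow on large logits. A minimal CPU sketch, independent of either library:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Max-shifted ("accurate") softmax: exp never sees a huge input.
std::vector<double> softmax_accurate(const std::vector<double>& x) {
  double mx = x[0];
  for (double v : x) mx = std::max(mx, v);
  double denom = 0.0;
  std::vector<double> out(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    out[i] = std::exp(x[i] - mx);
    denom += out[i];
  }
  for (double& v : out) v /= denom;
  return out;
}

int main() {
  // Naive exp(1000) overflows to inf; the shifted form is exact.
  std::vector<double> p = softmax_accurate({1000.0, 999.0, 998.0});
  std::printf("%f %f %f\n", p[0], p[1], p[2]);  // ~0.665 0.245 0.090
  return 0;
}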
45 changes: 7 additions & 38 deletions paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu
@@ -568,7 +568,7 @@ void BatchNormGradRawKernel(const Context &ctx,
scale.dims()[0]));

auto dtype = phi::backends::gpu::CudnnDataType<T>::type;
- #if defined(PADDLE_WITH_HIP) || defined(PADDLE_WITH_MUSA)
+ #ifdef PADDLE_WITH_HIP
auto compute_format =
data_layout == DataLayout::kNHWC ? DataLayout::kNHWC : DataLayout::kNCHW;

@@ -650,8 +650,7 @@
// platform::dynload::miopenCreateTensorDescriptor(&data_desc_));
// PADDLE_ENFORCE_GPU_SUCCESS(
// platform::dynload::miopenCreateTensorDescriptor(&bn_param_desc_));
- #elif defined(PADDLE_WITH_MUSA)
- #else
+ #elif defined(PADDLE_WITH_CUDA)
cudnnTensorDescriptor_t data_desc_;
cudnnTensorDescriptor_t bn_param_desc_;
cudnnBatchNormMode_t mode_;
@@ -697,16 +696,7 @@
// PADDLE_ENFORCE_GPU_SUCCESS(
// platform::dynload::miopenDeriveBNTensorDescriptor(bn_param_desc_,
// data_desc_, mode_));
- #elif defined(PADDLE_WITH_MUSA)
- PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::mudnnSetTensorNdDescriptor(
-     data_desc_,
-     CudnnDataType<T>::type,
-     x_dims.size() > 3 ? x_dims.size() : 4,
-     dims.data(),
-     strides.data()));
- PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::mudnnDeriveBNTensorDescriptor(
-     bn_param_desc_, data_desc_, mode_));
- #else
+ #elif defined(PADDLE_WITH_CUDA)
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnSetTensorNdDescriptor(
data_desc_,
CudnnDataType<T>::type,
@@ -789,9 +779,7 @@
// d_bias->template mutable_data<BatchNormParamType<T>>(
// ctx.GetPlace()),
// epsilon, saved_mean_data, saved_var_data));
- #elif defined(PADDLE_WITH_MUSA)
- 
- #else
+ #elif defined(PADDLE_WITH_CUDA)
}
// CUDNN only support small batch size
bool use_native_nhwc =
@@ -1127,12 +1115,7 @@
// platform::dynload::miopenDestroyTensorDescriptor(data_desc_));
// PADDLE_ENFORCE_GPU_SUCCESS(
// platform::dynload::miopenDestroyTensorDescriptor(bn_param_desc_));
- #elif defined(PADDLE_WITH_MUSA)
- PADDLE_ENFORCE_GPU_SUCCESS(
-     phi::dynload::mudnnDestroyTensorDescriptor(data_desc_));
- PADDLE_ENFORCE_GPU_SUCCESS(
-     phi::dynload::mudnnDestroyTensorDescriptor(bn_param_desc_));
- #else
+ #elif defined(PADDLE_WITH_CUDA)
// clean when exit.
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::cudnnDestroyTensorDescriptor(data_desc_));
@@ -1392,21 +1375,7 @@ PD_REGISTER_KERNEL(batch_norm_grad_raw,
phi::BatchNormGradRawKernel,
float,
phi::dtype::float16) {}
- #elif defined(PADDLE_WITH_MUSA)
- PD_REGISTER_KERNEL(batch_norm_grad,
-                    GPU,
-                    ALL_LAYOUT,
-                    phi::BatchNormGradKernel,
-                    float,
-                    phi::dtype::float16) {}
- 
- PD_REGISTER_KERNEL(batch_norm_grad_raw,
-                    GPU,
-                    ALL_LAYOUT,
-                    phi::BatchNormGradRawKernel,
-                    float,
-                    phi::dtype::float16) {}
- #else
+ #elif defined(PADDLE_WITH_CUDA)
#if CUDNN_VERSION_MIN(8, 1, 0)

PD_REGISTER_KERNEL(batch_norm_grad,
@@ -1440,7 +1409,7 @@
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); // bias_grad
}
}
- #else  // CUDA & MUSA
+ #else  // CUDA
PD_REGISTER_KERNEL(batch_norm_grad,
GPU,
ALL_LAYOUT,
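Note: every hunk in this file follows one pattern: the stub #elif defined(PADDLE_WITH_MUSA) branches are dropped and the trailing catch-all #else becomes an explicit #elif defined(PADDLE_WITH_CUDA), so a MUSA build can no longer fall through into cuDNN-only code. A compilable sketch of the shape of that dispatch (the flag names are Paddle's; the function body is a placeholder invented for this note):

#include <cstdio>

#define PADDLE_WITH_CUDA  // assume a CUDA build for this example

void init_dnn_handle() {
#ifdef PADDLE_WITH_HIP
  std::puts("create MIOpen handle");
#elif defined(PADDLE_WITH_MUSA)
  // muDNN port pending; the region intentionally compiles to nothing,
  // instead of silently taking the cuDNN branch as the old #else allowed.
#elif defined(PADDLE_WITH_CUDA)
  std::puts("create cuDNN handle");
#endif
}

int main() {
  init_dnn_handle();
  return 0;
}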
14 changes: 6 additions & 8 deletions paddle/phi/kernels/gpu/batch_norm_kernel.cu
@@ -554,7 +554,7 @@ void BatchNormKernel(const Context &ctx,

auto dtype = phi::backends::gpu::CudnnDataType<T>::type;

- #if defined(PADDLE_WITH_HIP) || defined(PADDLE_WITH_MUSA)
+ #ifdef PADDLE_WITH_HIP
auto compute_format =
data_layout == DataLayout::kNHWC ? DataLayout::kNHWC : DataLayout::kNCHW;

@@ -597,6 +597,7 @@
// PADDLE_ENFORCE_GPU_SUCCESS(
// platform::dynload::miopenCreateTensorDescriptor(&bn_param_desc_));
#elif defined(PADDLE_WITH_MUSA)
+ 
#else
cudnnTensorDescriptor_t data_desc_;
cudnnTensorDescriptor_t bn_param_desc_;
@@ -615,9 +616,11 @@
}
epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);

- #if defined(PADDLE_WITH_HIP) || defined(PADDLE_WITH_MUSA)
+ #ifdef PADDLE_WITH_HIP
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// mode_ = miopenBNSpatial;
+ #elif defined(PADDLE_WITH_MUSA)
+ 
#elif CUDNN_VERSION_MIN(7, 0, 1)
if (FLAGS_cudnn_batchnorm_spatial_persistent) {
mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
@@ -1210,12 +1213,7 @@
// platform::dynload::miopenDestroyTensorDescriptor(data_desc_));
// PADDLE_ENFORCE_GPU_SUCCESS(
// platform::dynload::miopenDestroyTensorDescriptor(bn_param_desc_));
- #elif defined(PADDLE_WITH_MUSA)
- PADDLE_ENFORCE_GPU_SUCCESS(
-     phi::dynload::mudnnDestroyTensorDescriptor(data_desc_));
- PADDLE_ENFORCE_GPU_SUCCESS(
-     phi::dynload::mudnnDestroyTensorDescriptor(bn_param_desc_));
- #else
+ #elif defined(PADDLE_WITH_CUDA)
// clean when exit.
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::cudnnDestroyTensorDescriptor(data_desc_));
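Note: the descriptors and modes configured above (CUDNN_BATCHNORM_SPATIAL / miopenBNSpatial, with epsilon clamped to CUDNN_BN_MIN_EPSILON) amount to per-channel normalization over N, H, and W. A backend-free NCHW sketch of the forward statistics, with scale and bias omitted for brevity:

#include <cmath>
#include <cstdio>
#include <vector>

// Spatial batch norm, forward only: one mean/variance pair per channel,
// reduced over N*H*W, then y = (x - mean) / sqrt(var + eps).
void batch_norm_spatial(const std::vector<float>& x, int n, int c, int hw,
                        float eps, std::vector<float>* y) {
  for (int ch = 0; ch < c; ++ch) {
    const int count = n * hw;
    double mean = 0.0, var = 0.0;
    for (int i = 0; i < n; ++i)
      for (int k = 0; k < hw; ++k) mean += x[(i * c + ch) * hw + k];
    mean /= count;
    for (int i = 0; i < n; ++i)
      for (int k = 0; k < hw; ++k) {
        double d = x[(i * c + ch) * hw + k] - mean;
        var += d * d;
      }
    var /= count;
    // The eps clamp mirrors CUDNN_BN_MIN_EPSILON: keep rsqrt finite.
    float inv_std = 1.0f / std::sqrt(static_cast<float>(var) + eps);
    for (int i = 0; i < n; ++i)
      for (int k = 0; k < hw; ++k) {
        int idx = (i * c + ch) * hw + k;
        (*y)[idx] = (x[idx] - static_cast<float>(mean)) * inv_std;
      }
  }
}

int main() {
  std::vector<float> x = {1, 2, 3, 4, 5, 6, 7, 8};  // N=2, C=2, H*W=2
  std::vector<float> y(x.size());
  batch_norm_spatial(x, 2, 2, 2, 1e-5f, &y);
  for (float v : y) std::printf("%.3f ", v);  // each channel: mean 0, var 1
  std::printf("\n");
  return 0;
}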
35 changes: 4 additions & 31 deletions paddle/phi/kernels/gpu/instance_norm_grad_kernel.cu
@@ -401,15 +401,7 @@ void InstanceNormGradKernel(const Context &dev_ctx,
phi::dynload::miopenCreateTensorDescriptor(&data_desc_));
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::miopenCreateTensorDescriptor(&in_param_desc_));
- #elif defined(PADDLE_WITH_MUSA)
- mudnnTensorDescriptor_t data_desc_;
- mudnnTensorDescriptor_t in_param_desc_;
- 
- PADDLE_ENFORCE_GPU_SUCCESS(
-     phi::dynload::mudnnCreateTensorDescriptor(&data_desc_));
- PADDLE_ENFORCE_GPU_SUCCESS(
-     phi::dynload::mudnnCreateTensorDescriptor(&in_param_desc_));
- #else
+ #elif defined(PADDLE_WITH_CUDA)
cudnnTensorDescriptor_t data_desc_;
cudnnTensorDescriptor_t in_param_desc_;

@@ -435,16 +427,7 @@
const_cast<int *>(strides.data())));
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::miopenDeriveBNTensorDescriptor(
in_param_desc_, data_desc_, miopenBNSpatial));
- #elif defined(PADDLE_WITH_MUSA)
- PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::mudnnSetTensorDescriptor(
-     data_desc_,
-     CudnnDataType<T>::type,
-     x_dims.size() > 3 ? x_dims.size() : 4,
-     const_cast<int *>(dims.data()),
-     const_cast<int *>(strides.data())));
- PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::mudnnDeriveBNTensorDescriptor(
-     in_param_desc_, data_desc_, miopenBNSpatial));
- #else
+ #elif defined(PADDLE_WITH_CUDA)
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnSetTensorNdDescriptor(
data_desc_,
CudnnDataType<T>::type,
@@ -481,14 +464,9 @@
epsilon,
saved_mean_data,
saved_var_data));
- #else
- #ifdef PADDLE_WITH_MUSA
- PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::mudnnBatchNormalizationBackward(
-     dev_ctx.mudnn_handle(),
- #else
+ #elif defined(PADDLE_WITH_CUDA)
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnBatchNormalizationBackward(
dev_ctx.cudnn_handle(),
- #endif
CUDNN_BATCHNORM_SPATIAL,
CudnnDataType<T>::kOne(),
CudnnDataType<T>::kZero(),
@@ -533,12 +511,7 @@
phi::dynload::miopenDestroyTensorDescriptor(data_desc_));
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::miopenDestroyTensorDescriptor(in_param_desc_));
- #elif defined(PADDLE_WITH_MUSA)
- PADDLE_ENFORCE_GPU_SUCCESS(
-     phi::dynload::mudnnDestroyTensorDescriptor(data_desc_));
- PADDLE_ENFORCE_GPU_SUCCESS(
-     phi::dynload::mudnnDestroyTensorDescriptor(in_param_desc_));
- #else
+ #elif defined(PADDLE_WITH_CUDA)
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::cudnnDestroyTensorDescriptor(data_desc_));
PADDLE_ENFORCE_GPU_SUCCESS(
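Note: the surviving CUDA path above follows the usual descriptor lifecycle: cudnnCreateTensorDescriptor, then cudnnSetTensorNdDescriptor (clamping to at least 4 dims), the cudnnBatchNormalizationBackward call, and cudnnDestroyTensorDescriptor on exit. A stand-in sketch of that ordering which compiles without cuDNN (Descriptor and the helper functions are invented for this note; the ordering is the point):

#include <cstdio>

struct Descriptor {  // stand-in for cudnnTensorDescriptor_t
  int ndims = 0;
};

Descriptor* create_descriptor() { return new Descriptor; }     // ~cudnnCreate...
void set_descriptor(Descriptor* d, int nd) { d->ndims = nd; }  // ~cudnnSetTensorNd...
void destroy_descriptor(Descriptor* d) { delete d; }           // ~cudnnDestroy...

int main() {
  Descriptor* data_desc = create_descriptor();
  Descriptor* in_param_desc = create_descriptor();
  // Batch-norm style descriptors must describe at least 4 dims, hence the
  // x_dims.size() > 3 ? x_dims.size() : 4 clamp in the kernel above.
  set_descriptor(data_desc, 4);
  set_descriptor(in_param_desc, 4);
  std::printf("use descriptors: %d dims\n", data_desc->ndims);
  // Destroy in the epilogue, mirroring the "clean when exit" comment.
  destroy_descriptor(data_desc);
  destroy_descriptor(in_param_desc);
  return 0;
}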
