From 9a5a741ddb5cb91152909a23a6494800e560735d Mon Sep 17 00:00:00 2001 From: Critsium-xy Date: Fri, 8 Nov 2024 18:44:19 +0800 Subject: [PATCH 1/3] Fix warnings --- source/module_base/blas_connector.cpp | 31 +++++++------------ .../kernels/cuda/math_kernel_op.cu | 2 +- 2 files changed, 12 insertions(+), 21 deletions(-) diff --git a/source/module_base/blas_connector.cpp b/source/module_base/blas_connector.cpp index 61ea4b390f..58c678e0f2 100644 --- a/source/module_base/blas_connector.cpp +++ b/source/module_base/blas_connector.cpp @@ -69,17 +69,17 @@ float BlasConnector::dot( const int n, const float *X, const int incX, const flo { if (device_type == base_device::AbacusDevice_t::CpuDevice) { return sdot_(&n, X, &incX, Y, &incY); + } return sdot_(&n, X, &incX, Y, &incY); } -} double BlasConnector::dot( const int n, const double *X, const int incX, const double *Y, const int incY, base_device::AbacusDevice_t device_type) { if (device_type == base_device::AbacusDevice_t::CpuDevice) { return ddot_(&n, X, &incX, Y, &incY); + } return ddot_(&n, X, &incX, Y, &incY); } -} // C = a * A.? * B.? + b * C void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k, @@ -93,7 +93,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons } #ifdef __DSP else if (device_type == base_device::AbacusDevice_t::DspDevice){ - sgemm_mth_(&transb, &transa, &n, &m, &k, + sgemm_mt_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK); } @@ -111,7 +111,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons } #ifdef __DSP else if (device_type == base_device::AbacusDevice_t::DspDevice){ - dgemm_mth_(&transb, &transa, &n, &m, &k, + dgemm_mt_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK); } @@ -129,7 +129,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons } #ifdef __DSP else if (device_type == base_device::AbacusDevice_t::DspDevice) { - cgemm_mth_(&transb, &transa, &n, &m, &k, + cgemm_mt_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK); } @@ -147,22 +147,13 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons } #ifdef __DSP else if (device_type == base_device::AbacusDevice_t::DspDevice) { - zgemm_mth_(&transb, &transa, &n, &m, &k, + zgemm_mt_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK); } #endif } -void BlasConnector::gemv(const char trans, const int m, const int n, - const float alpha, const float* A, const int lda, const float* X, const int incx, - const float beta, float* Y, const int incy, base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::AbacusDevice_t::CpuDevice) { - sgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy); -} -} - void BlasConnector::gemv(const char trans, const int m, const int n, const double alpha, const double* A, const int lda, const double* X, const int incx, const double beta, double* Y, const int incy, base_device::AbacusDevice_t device_type) @@ -196,39 +187,39 @@ float BlasConnector::nrm2( const int n, const float *X, const int incX, base_dev { if (device_type == base_device::AbacusDevice_t::CpuDevice) { return snrm2_( &n, X, &incX ); + } return snrm2_( &n, X, &incX ); } -} double BlasConnector::nrm2( const int n, const double *X, const int incX, base_device::AbacusDevice_t device_type ) { if (device_type == base_device::AbacusDevice_t::CpuDevice) { return dnrm2_( &n, X, &incX ); + } return dnrm2_( &n, X, &incX ); } -} double BlasConnector::nrm2( const int n, const std::complex *X, const int incX, base_device::AbacusDevice_t device_type ) { if (device_type == base_device::AbacusDevice_t::CpuDevice) { return dznrm2_( &n, X, &incX ); + } return dznrm2_( &n, X, &incX ); } -} // copies a into b void BlasConnector::copy(const long n, const double *a, const int incx, double *b, const int incy, base_device::AbacusDevice_t device_type) { if (device_type == base_device::AbacusDevice_t::CpuDevice) { dcopy_(&n, a, &incx, b, &incy); -} + } } void BlasConnector::copy(const long n, const std::complex *a, const int incx, std::complex *b, const int incy, base_device::AbacusDevice_t device_type) { if (device_type == base_device::AbacusDevice_t::CpuDevice) { zcopy_(&n, a, &incx, b, &incy); -} + } } \ No newline at end of file diff --git a/source/module_hsolver/kernels/cuda/math_kernel_op.cu b/source/module_hsolver/kernels/cuda/math_kernel_op.cu index c5a49b85e3..930ac0b3ce 100644 --- a/source/module_hsolver/kernels/cuda/math_kernel_op.cu +++ b/source/module_hsolver/kernels/cuda/math_kernel_op.cu @@ -12,7 +12,7 @@ namespace hsolver { const int warp_size = 32; -const unsigned int full_mask = 0xffffffff; +//const unsigned int full_mask = 0xffffffff; const int thread_per_block = 256; } From 0e0d7fb80051284857a5f0a94da47eed52f3dc5e Mon Sep 17 00:00:00 2001 From: Critsium-xy Date: Fri, 8 Nov 2024 18:49:44 +0800 Subject: [PATCH 2/3] Fix cuda compiling bug --- source/module_hsolver/kernels/cuda/math_kernel_op.cu | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/source/module_hsolver/kernels/cuda/math_kernel_op.cu b/source/module_hsolver/kernels/cuda/math_kernel_op.cu index 930ac0b3ce..6185433895 100644 --- a/source/module_hsolver/kernels/cuda/math_kernel_op.cu +++ b/source/module_hsolver/kernels/cuda/math_kernel_op.cu @@ -12,7 +12,7 @@ namespace hsolver { const int warp_size = 32; -//const unsigned int full_mask = 0xffffffff; +// const unsigned int full_mask = 0xffffffff; const int thread_per_block = 256; } @@ -65,11 +65,11 @@ void destoryBLAShandle(){ } } -template -__forceinline__ __device__ void warp_reduce(FPTYPE& val) { - for (int offset = 16; offset > 0; offset >>= 1) - val += __shfl_down_sync(full_mask, val, offset); -} +// template +// __forceinline__ __device__ void warp_reduce(FPTYPE& val) { +// for (int offset = 16; offset > 0; offset >>= 1) +// val += __shfl_down_sync(full_mask, val, offset); +// } template __global__ void line_minimize_with_block( From 5cab0d07d9a4ddafda4afeed81c8ad4b7bbbf9f3 Mon Sep 17 00:00:00 2001 From: Critsium-xy Date: Fri, 8 Nov 2024 19:37:25 +0800 Subject: [PATCH 3/3] Fix compiling error --- source/module_base/blas_connector.cpp | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/source/module_base/blas_connector.cpp b/source/module_base/blas_connector.cpp index 58c678e0f2..69f51b744f 100644 --- a/source/module_base/blas_connector.cpp +++ b/source/module_base/blas_connector.cpp @@ -93,7 +93,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons } #ifdef __DSP else if (device_type == base_device::AbacusDevice_t::DspDevice){ - sgemm_mt_(&transb, &transa, &n, &m, &k, + sgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK); } @@ -111,7 +111,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons } #ifdef __DSP else if (device_type == base_device::AbacusDevice_t::DspDevice){ - dgemm_mt_(&transb, &transa, &n, &m, &k, + dgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK); } @@ -129,7 +129,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons } #ifdef __DSP else if (device_type == base_device::AbacusDevice_t::DspDevice) { - cgemm_mt_(&transb, &transa, &n, &m, &k, + cgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK); } @@ -147,13 +147,22 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons } #ifdef __DSP else if (device_type == base_device::AbacusDevice_t::DspDevice) { - zgemm_mt_(&transb, &transa, &n, &m, &k, + zgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK); } #endif } +void BlasConnector::gemv(const char trans, const int m, const int n, + const float alpha, const float* A, const int lda, const float* X, const int incx, + const float beta, float* Y, const int incy, base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) { + sgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy); +} +} + void BlasConnector::gemv(const char trans, const int m, const int n, const double alpha, const double* A, const int lda, const double* X, const int incx, const double beta, double* Y, const int incy, base_device::AbacusDevice_t device_type)