From 0e17595484243099990f72cfe0bdd69cd0b63a31 Mon Sep 17 00:00:00 2001 From: Critsium-xy Date: Tue, 18 Nov 2025 14:27:46 +0800 Subject: [PATCH] Add INPUT parameter of dsp counts --- source/source_base/module_device/memory_op.cpp | 3 ++- .../module_external/blas_connector_matrix.cpp | 17 +++++++++-------- source/source_base/module_fft/fft_dsp.cpp | 3 ++- .../module_parameter/input_parameter.h | 3 +++ source/source_main/driver_run.cpp | 2 +- 5 files changed, 17 insertions(+), 11 deletions(-) diff --git a/source/source_base/module_device/memory_op.cpp b/source/source_base/module_device/memory_op.cpp index 51ff26de4b..ac2549e182 100644 --- a/source/source_base/module_device/memory_op.cpp +++ b/source/source_base/module_device/memory_op.cpp @@ -5,6 +5,7 @@ #ifdef __DSP #include "source_base/kernels/dsp/dsp_connector.h" #include "source_base/global_variable.h" +#include "source_io/module_parameter/parameter.h" #endif #include @@ -452,7 +453,7 @@ struct resize_memory_op_mt { mtfunc::free_ht(arr); } - arr = (FPTYPE*)mtfunc::malloc_ht(sizeof(FPTYPE) * size, GlobalV::MY_RANK % 4); + arr = (FPTYPE*)mtfunc::malloc_ht(sizeof(FPTYPE) * size, GlobalV::MY_RANK % PARAM.inp.dsp_count); std::string record_string; if (record_in != nullptr) { diff --git a/source/source_base/module_external/blas_connector_matrix.cpp b/source/source_base/module_external/blas_connector_matrix.cpp index cdaddc0b77..62a4e46a43 100644 --- a/source/source_base/module_external/blas_connector_matrix.cpp +++ b/source/source_base/module_external/blas_connector_matrix.cpp @@ -4,6 +4,7 @@ #ifdef __DSP #include "source_base/kernels/dsp/dsp_connector.h" #include "source_base/global_variable.h" +#include "source_io/module_parameter/parameter.h" #endif #ifdef __CUDA @@ -30,7 +31,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons else if (device_type == base_device::AbacusDevice_t::DspDevice){ mtfunc::sgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, - &beta, c, &ldc, 
GlobalV::MY_RANK % 4); + &beta, c, &ldc, GlobalV::MY_RANK % PARAM.inp.dsp_count); } #endif #ifdef __CUDA @@ -67,7 +68,7 @@ void BlasConnector::gemm(const char transa, #ifdef __DSP else if (device_type == base_device::AbacusDevice_t::DspDevice) { - mtfunc::dgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK % 4); + mtfunc::dgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK % PARAM.inp.dsp_count); } #endif else if (device_type == base_device::AbacusDevice_t::GpuDevice) @@ -106,7 +107,7 @@ void BlasConnector::gemm(const char transa, #ifdef __DSP else if (device_type == base_device::AbacusDevice_t::DspDevice) { - mtfunc::cgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK % 4); + mtfunc::cgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK % PARAM.inp.dsp_count); } #endif else if (device_type == base_device::AbacusDevice_t::GpuDevice) @@ -157,7 +158,7 @@ void BlasConnector::gemm(const char transa, #ifdef __DSP else if (device_type == base_device::AbacusDevice_t::DspDevice) { - mtfunc::zgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK % 4); + mtfunc::zgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK % PARAM.inp.dsp_count); } #endif else if (device_type == base_device::AbacusDevice_t::GpuDevice) @@ -200,7 +201,7 @@ void BlasConnector::gemm_cm(const char transa, const char transb, const int m, c else if (device_type == base_device::AbacusDevice_t::DspDevice){ mtfunc::sgemm_mth_(&transb, &transa, &m, &n, &k, &alpha, a, &lda, b, &ldb, - &beta, c, &ldc, GlobalV::MY_RANK % 4); + &beta, c, &ldc, GlobalV::MY_RANK % PARAM.inp.dsp_count); } #endif #ifdef __CUDA @@ -237,7 +238,7 @@ void BlasConnector::gemm_cm(const char transa, #ifdef __DSP else if (device_type == 
base_device::AbacusDevice_t::DspDevice) { - mtfunc::dgemm_mth_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK % 4); + mtfunc::dgemm_mth_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK % PARAM.inp.dsp_count); } #endif #ifdef __CUDA @@ -276,7 +277,7 @@ void BlasConnector::gemm_cm(const char transa, #ifdef __DSP else if (device_type == base_device::AbacusDevice_t::DspDevice) { - mtfunc::cgemm_mth_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK % 4); + mtfunc::cgemm_mth_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK % PARAM.inp.dsp_count); } #endif #ifdef __CUDA @@ -327,7 +328,7 @@ void BlasConnector::gemm_cm(const char transa, #ifdef __DSP else if (device_type == base_device::AbacusDevice_t::DspDevice) { - mtfunc::zgemm_mth_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK % 4); + mtfunc::zgemm_mth_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK % PARAM.inp.dsp_count); } #endif #ifdef __CUDA diff --git a/source/source_base/module_fft/fft_dsp.cpp b/source/source_base/module_fft/fft_dsp.cpp index 4535e0d76e..ca4a0cd1e4 100644 --- a/source/source_base/module_fft/fft_dsp.cpp +++ b/source/source_base/module_fft/fft_dsp.cpp @@ -2,6 +2,7 @@ #include "source_base/global_variable.h" #include "source_base/global_function.h" +#include "source_io/module_parameter/parameter.h" #include #include @@ -14,7 +15,7 @@ void FFT_DSP::initfft(int nx_in, int ny_in, int nz_in) this->nx = nx_in; this->ny = ny_in; this->nz = nz_in; - cluster_id = GlobalV::MY_RANK % 4; + cluster_id = GlobalV::MY_RANK % PARAM.inp.dsp_count; nxyz = this->nx * this->ny * this->nz; } template <> diff --git a/source/source_io/module_parameter/input_parameter.h b/source/source_io/module_parameter/input_parameter.h index 9f5d452727..605db6572f 100644 --- 
a/source/source_io/module_parameter/input_parameter.h +++ b/source/source_io/module_parameter/input_parameter.h @@ -692,5 +692,8 @@ struct Input_para bool of_cd = false; ///< add CD potential or not https://doi.org/10.1103/PhysRevB.98.144302 double of_mCD_alpha = 1.0; /// parameter of modified CD Potential + // ============== #Parameters (25.uncommon hardware) ================= + int dsp_count = 4; /// number of DSP devices in one node; must be > 0 (used as the divisor in MY_RANK % dsp_count) + }; #endif diff --git a/source/source_main/driver_run.cpp b/source/source_main/driver_run.cpp index 50eb11631c..1ba2ea9d51 100644 --- a/source/source_main/driver_run.cpp +++ b/source/source_main/driver_run.cpp @@ -129,7 +129,7 @@ void Driver::init_hardware() #ifdef __DSP std::cout << " ** Initializing DSP Hardware..." << std::endl; - mtfunc::dspInitHandle(GlobalV::MY_RANK % 4); + mtfunc::dspInitHandle(GlobalV::MY_RANK % PARAM.inp.dsp_count); #endif }