diff --git a/source/source_esolver/esolver_ks_lcao_tddft.cpp b/source/source_esolver/esolver_ks_lcao_tddft.cpp
index d608082ddd..3860856bd7 100644
--- a/source/source_esolver/esolver_ks_lcao_tddft.cpp
+++ b/source/source_esolver/esolver_ks_lcao_tddft.cpp
@@ -318,7 +318,7 @@ void ESolver_KS_LCAO_TDDFT
::iter_finish(UnitCell& ucell,
if (conv_esolver && estep == estep_max - 1 && istep >= (PARAM.inp.init_wfc == "file" ? 0 : 1)
&& PARAM.inp.td_edm == 0)
{
- elecstate::cal_edm_tddft(this->pv, this->dmat, this->kv, this->p_hamilt);
+ elecstate::cal_edm_tddft_tensor(this->pv, this->dmat, this->kv, this->p_hamilt);
}
}
diff --git a/source/source_estate/module_dm/cal_edm_tddft.cpp b/source/source_estate/module_dm/cal_edm_tddft.cpp
index 46168a56e0..c09f0349b7 100644
--- a/source/source_estate/module_dm/cal_edm_tddft.cpp
+++ b/source/source_estate/module_dm/cal_edm_tddft.cpp
@@ -1,10 +1,56 @@
#include "cal_edm_tddft.h"
+#include "source_base/module_container/ATen/core/tensor.h" // For ct::Tensor
+#include "source_base/module_container/ATen/kernels/blas.h"
+#include "source_base/module_container/ATen/kernels/lapack.h"
+#include "source_base/module_container/ATen/kernels/memory.h" // memory operations (Tensor)
+#include "source_base/module_device/memory_op.h" // memory operations
#include "source_base/module_external/lapack_connector.h"
#include "source_base/module_external/scalapack_connector.h"
#include "source_io/module_parameter/parameter.h" // use PARAM.globalv
+
namespace elecstate
{
+void print_local_matrix(std::ostream& os,
+ const std::complex* matrix_data,
+ int local_rows, // pv.nrow
+ int local_cols, // pv.ncol
+ const std::string& matrix_name = "",
+ int rank = -1)
+{
+ if (!matrix_name.empty() || rank >= 0)
+ {
+ os << "=== ";
+ if (!matrix_name.empty())
+ {
+ os << "Matrix: " << matrix_name;
+ if (rank >= 0)
+ os << " ";
+ }
+ if (rank >= 0)
+ {
+ os << "(Process: " << rank + 1 << ")";
+ }
+ os << " (Local dims: " << local_rows << " x " << local_cols << ") ===" << std::endl;
+ }
+
+ os << std::fixed << std::setprecision(10) << std::showpos;
+
+ for (int i = 0; i < local_rows; ++i) // Iterate over rows (i)
+ {
+ for (int j = 0; j < local_cols; ++j) // Iterate over columns (j)
+ {
+ // For column-major storage, element (i, j) is at index i + j * LDA
+ // where LDA (leading dimension) is typically the number of *rows* in the local block.
+ int idx = i + j * local_rows;
+ os << "(" << std::real(matrix_data[idx]) << "," << std::imag(matrix_data[idx]) << ") ";
+ }
+ os << std::endl; // New line after each row
+ }
+ os.unsetf(std::ios_base::fixed | std::ios_base::showpos);
+ os << std::endl;
+}
+
// use the original formula (Hamiltonian matrix) to calculate energy density matrix
void cal_edm_tddft(Parallel_Orbitals& pv,
LCAO_domain::Setup_DM>& dmat,
@@ -252,4 +298,260 @@ void cal_edm_tddft(Parallel_Orbitals& pv,
ModuleBase::timer::tick("elecstate", "cal_edm_tddft");
return;
} // cal_edm_tddft
+
+void cal_edm_tddft_tensor(Parallel_Orbitals& pv,
+ LCAO_domain::Setup_DM>& dmat,
+ K_Vectors& kv,
+ hamilt::Hamilt>* p_hamilt)
+{
+ ModuleBase::timer::tick("elecstate", "cal_edm_tddft_tensor");
+
+ const int nlocal = PARAM.globalv.nlocal;
+ assert(nlocal >= 0);
+ dmat.dm->EDMK.resize(kv.get_nks());
+
+ for (int ik = 0; ik < kv.get_nks(); ++ik)
+ {
+ p_hamilt->updateHk(ik);
+ std::complex* tmp_dmk = dmat.dm->get_DMK_pointer(ik);
+ ModuleBase::ComplexMatrix& tmp_edmk = dmat.dm->EDMK[ik];
+
+#ifdef __MPI
+ const int nloc = pv.nloc;
+ const int ncol = pv.ncol;
+ const int nrow = pv.nrow;
+
+ // Initialize EDMK matrix
+ tmp_edmk.create(ncol, nrow);
+
+ // Allocate Tensor objects on CPU
+ ct::Tensor Htmp_tensor(ct::DataType::DT_COMPLEX_DOUBLE, ct::DeviceType::CpuDevice, ct::TensorShape({nloc}));
+ Htmp_tensor.zero();
+
+ ct::Tensor Sinv_tensor(ct::DataType::DT_COMPLEX_DOUBLE, ct::DeviceType::CpuDevice, ct::TensorShape({nloc}));
+ Sinv_tensor.zero();
+
+ ct::Tensor tmp1_tensor(ct::DataType::DT_COMPLEX_DOUBLE, ct::DeviceType::CpuDevice, ct::TensorShape({nloc}));
+ tmp1_tensor.zero();
+
+ ct::Tensor tmp2_tensor(ct::DataType::DT_COMPLEX_DOUBLE, ct::DeviceType::CpuDevice, ct::TensorShape({nloc}));
+ tmp2_tensor.zero();
+
+ ct::Tensor tmp3_tensor(ct::DataType::DT_COMPLEX_DOUBLE, ct::DeviceType::CpuDevice, ct::TensorShape({nloc}));
+ tmp3_tensor.zero();
+
+ ct::Tensor tmp4_tensor(ct::DataType::DT_COMPLEX_DOUBLE, ct::DeviceType::CpuDevice, ct::TensorShape({nloc}));
+ tmp4_tensor.zero();
+
+ // Get raw pointers from tensors for ScaLAPACK calls
+ std::complex* Htmp_ptr = Htmp_tensor.data>();
+ std::complex* Sinv_ptr = Sinv_tensor.data>();
+ std::complex* tmp1_ptr = tmp1_tensor.data>();
+ std::complex* tmp2_ptr = tmp2_tensor.data>();
+ std::complex* tmp3_ptr = tmp3_tensor.data>();
+ std::complex* tmp4_ptr = tmp4_tensor.data>();
+
+ const int inc = 1;
+ hamilt::MatrixBlock> h_mat;
+ hamilt::MatrixBlock> s_mat;
+ p_hamilt->matrix(h_mat, s_mat);
+
+ // Copy Hamiltonian and Overlap matrices into Tensor buffers using BlasConnector
+ BlasConnector::copy(nloc, h_mat.p, inc, Htmp_ptr, inc);
+ BlasConnector::copy(nloc, s_mat.p, inc, Sinv_ptr, inc);
+
+ // --- ScaLAPACK Inversion of S ---
+ ct::Tensor ipiv_tensor(ct::DataType::DT_INT,
+ ct::DeviceType::CpuDevice,
+ ct::TensorShape({pv.nrow + pv.nb})); // Size for ScaLAPACK pivot array
+ ipiv_tensor.zero();
+ int* ipiv_ptr = ipiv_tensor.data();
+
+ int info = 0;
+ const int one_int = 1;
+ ScalapackConnector::getrf(nlocal, nlocal, Sinv_ptr, one_int, one_int, pv.desc, ipiv_ptr, &info);
+
+ int lwork = -1;
+ int liwork = -1;
+ ct::Tensor work_query_tensor(ct::DataType::DT_COMPLEX_DOUBLE, ct::DeviceType::CpuDevice, ct::TensorShape({1}));
+ ct::Tensor iwork_query_tensor(ct::DataType::DT_INT, ct::DeviceType::CpuDevice, ct::TensorShape({1}));
+
+ ScalapackConnector::getri(nlocal,
+ Sinv_ptr,
+ one_int,
+ one_int,
+ pv.desc,
+ ipiv_ptr,
+ work_query_tensor.data>(),
+ &lwork,
+ iwork_query_tensor.data(),
+ &liwork,
+ &info);
+
+ // Resize work arrays based on query results
+ lwork = work_query_tensor.data>()[0].real();
+ work_query_tensor.resize(ct::TensorShape({lwork}));
+ liwork = iwork_query_tensor.data()[0];
+ iwork_query_tensor.resize(ct::TensorShape({liwork}));
+
+ ScalapackConnector::getri(nlocal,
+ Sinv_ptr,
+ one_int,
+ one_int,
+ pv.desc,
+ ipiv_ptr,
+ work_query_tensor.data>(),
+ &lwork,
+ iwork_query_tensor.data(),
+ &liwork,
+ &info);
+
+ // --- EDM Calculation using ScaLAPACK ---
+ const char N_char = 'N';
+ const char T_char = 'T';
+ const std::complex one_complex = {1.0, 0.0};
+ const std::complex zero_complex = {0.0, 0.0};
+ const std::complex half_complex = {0.5, 0.0};
+
+ // tmp1 = Sinv * Htmp (result stored in tmp1)
+ ScalapackConnector::gemm(N_char,
+ N_char,
+ nlocal,
+ nlocal,
+ nlocal,
+ one_complex,
+ Sinv_ptr,
+ one_int,
+ one_int,
+ pv.desc,
+ Htmp_ptr,
+ one_int,
+ one_int,
+ pv.desc,
+ zero_complex,
+ tmp1_ptr,
+ one_int,
+ one_int,
+ pv.desc);
+
+ // tmp2 = tmp1 * tmp_dmk (result stored in tmp2)
+ ScalapackConnector::gemm(N_char,
+ N_char,
+ nlocal,
+ nlocal,
+ nlocal,
+ one_complex,
+ tmp1_ptr,
+ one_int,
+ one_int,
+ pv.desc,
+ tmp_dmk,
+ one_int,
+ one_int,
+ pv.desc,
+ zero_complex,
+ tmp2_ptr,
+ one_int,
+ one_int,
+ pv.desc);
+
+ // tmp3 = Htmp * Sinv (result stored in tmp3)
+ ScalapackConnector::gemm(N_char,
+ N_char,
+ nlocal,
+ nlocal,
+ nlocal,
+ one_complex,
+ Htmp_ptr,
+ one_int,
+ one_int,
+ pv.desc,
+ Sinv_ptr,
+ one_int,
+ one_int,
+ pv.desc,
+ zero_complex,
+ tmp3_ptr,
+ one_int,
+ one_int,
+ pv.desc);
+
+ // tmp4 = tmp_dmk * tmp3 (result stored in tmp4)
+ ScalapackConnector::gemm(N_char,
+ N_char,
+ nlocal,
+ nlocal,
+ nlocal,
+ one_complex,
+ tmp_dmk,
+ one_int,
+ one_int,
+ pv.desc,
+ tmp3_ptr,
+ one_int,
+ one_int,
+ pv.desc,
+ zero_complex,
+ tmp4_ptr,
+ one_int,
+ one_int,
+ pv.desc);
+
+ // tmp4 = 0.5 * tmp2 + 0.5 * tmp4 (final EDM contribution)
+ ScalapackConnector::geadd(N_char,
+ nlocal,
+ nlocal,
+ half_complex,
+ tmp2_ptr,
+ one_int,
+ one_int,
+ pv.desc,
+ half_complex,
+ tmp4_ptr,
+ one_int,
+ one_int,
+ pv.desc);
+
+ // Copy final result from Tensor buffer back to EDMK matrix
+ BlasConnector::copy(nloc, tmp4_ptr, inc, tmp_edmk.c, inc);
+
+#else
+ // Serial version remains unchanged, using ModuleBase::ComplexMatrix directly
+ tmp_edmk.create(pv.ncol, pv.nrow);
+ ModuleBase::ComplexMatrix Sinv(nlocal, nlocal);
+ ModuleBase::ComplexMatrix Htmp(nlocal, nlocal);
+ hamilt::MatrixBlock> h_mat;
+ hamilt::MatrixBlock> s_mat;
+ p_hamilt->matrix(h_mat, s_mat);
+ for (int i = 0; i < nlocal; i++)
+ {
+ for (int j = 0; j < nlocal; j++)
+ {
+ Htmp(i, j) = h_mat.p[i * nlocal + j];
+ Sinv(i, j) = s_mat.p[i * nlocal + j];
+ }
+ }
+ int INFO = 0;
+ int lwork = 3 * nlocal - 1; // tmp
+ std::complex* work = new std::complex[lwork];
+ ModuleBase::GlobalFunc::ZEROS(work, lwork);
+ int IPIV[nlocal];
+ LapackConnector::zgetrf(nlocal, nlocal, Sinv, nlocal, IPIV, &INFO);
+ LapackConnector::zgetri(nlocal, Sinv, nlocal, IPIV, work, lwork, &INFO);
+ ModuleBase::ComplexMatrix tmp_dmk_base(nlocal, nlocal);
+ for (int i = 0; i < nlocal; i++)
+ {
+ for (int j = 0; j < nlocal; j++)
+ {
+ tmp_dmk_base(i, j) = tmp_dmk[i * nlocal + j];
+ }
+ }
+ tmp_edmk = 0.5 * (Sinv * Htmp * tmp_dmk_base + tmp_dmk_base * Htmp * Sinv);
+ delete[] work;
+#endif
+ } // end ik
+ ModuleBase::timer::tick("elecstate", "cal_edm_tddft_tensor");
+ return;
+} // cal_edm_tddft_tensor
+
} // namespace elecstate
diff --git a/source/source_estate/module_dm/cal_edm_tddft.h b/source/source_estate/module_dm/cal_edm_tddft.h
index 5ceec1bdaa..004c753ba7 100644
--- a/source/source_estate/module_dm/cal_edm_tddft.h
+++ b/source/source_estate/module_dm/cal_edm_tddft.h
@@ -9,8 +9,13 @@
namespace elecstate
{
void cal_edm_tddft(Parallel_Orbitals& pv,
- LCAO_domain::Setup_DM> &dmat,
+ LCAO_domain::Setup_DM>& dmat,
K_Vectors& kv,
hamilt::Hamilt>* p_hamilt);
+
+void cal_edm_tddft_tensor(Parallel_Orbitals& pv,
+ LCAO_domain::Setup_DM>& dmat,
+ K_Vectors& kv,
+ hamilt::Hamilt>* p_hamilt);
} // namespace elecstate
#endif // CAL_EDM_TDDFT_H