Skip to content

Commit

Permalink
Feature: Add a new Davidson iteration method called subspace davidson…
Browse files Browse the repository at this point in the history
… for pw basis (#3903)

* add new_dav method which is same to davidson method for pw basis

* add new_dav method for pw basis which is more efficient than origin dav method

* fix ilaenv interface bug

* update new_dav method 3.5

* implement new_dav method for pw basis (cpu version, one core)

* fix bug of dav method for pw basis

* debug for new davidson method

* opt some value setting for new_dav files

* format and reorganize the code

* fix CUDA compile bug

* format diago_newdav.cpp

* Implement multi-core parallelism of the new davidson method

* fix build bug for without mpi

* replace new-dav of subspace-dav

* change file name from diago_newdav to diago_subspacedav

* fix build bug for tests

* change the name of subspacedav to dav_subspace

* fix build bug in Integration Test

---------

Co-authored-by: Mohan Chen <mohan.chen.chen.mohan@gmail.com>
  • Loading branch information
haozhihan and mohanchen committed Apr 9, 2024
1 parent 125581b commit 21d40ce
Show file tree
Hide file tree
Showing 19 changed files with 2,269 additions and 749 deletions.
1 change: 1 addition & 0 deletions source/Makefile.Objects
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ OBJS_HCONTAINER=base_matrix.o\

OBJS_HSOLVER=diago_cg.o\
diago_david.o\
diago_dav_subspace.o\
diagh_consts.o\
diago_bpcg.o\
hsolver_pw.o\
Expand Down
184 changes: 119 additions & 65 deletions source/module_base/lapack_connector.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
extern "C"
{

int ilaenv_(int* ispec,const char* name,const char* opts,
const int* n1,const int* n2,const int* n3,const int* n4);

// solve the generalized eigenproblem Ax=eBx, where A is Hermitian and complex couble
// zhegv_ & zhegvd_ returns all eigenvalues while zhegvx_ returns selected ones
void dsygvd_(const int* itype, const char* jobz, const char* uplo, const int* n,
Expand Down Expand Up @@ -60,9 +63,12 @@ extern "C"
const int* m, double* w, std::complex<double> *z, const int *ldz,
std::complex<double> *work, const int* lwork, double* rwork, int* iwork, int* ifail, int* info);

void zhegv_(const int* itype,const char* jobz,const char* uplo,const int* n,
std::complex<double>* a,const int* lda,std::complex<double>* b,const int* ldb,
double* w,std::complex<double>* work,int* lwork,double* rwork,int* info);

void dsygvx_(const int* itype, const char* jobz, const char* range, const char* uplo,
const int* n, double* A, const int* lda, double* B, const int* ldb,
const double* vl, const double* vu, const int* il, const int* iu,
const double* abstol, const int* m, double* w, double* Z, const int* ldz,
double* work, int* lwork, int*iwork, int* ifail, int* info);

void chegvx_(const int* itype,const char* jobz,const char* range,const char* uplo,
const int* n,std::complex<float> *a,const int* lda,std::complex<float> *b,
Expand All @@ -78,6 +84,16 @@ extern "C"
std::complex<double> *z,const int *ldz,std::complex<double> *work,const int* lwork,
double* rwork,int* iwork,int* ifail,int* info);

void zhegv_(const int* itype,const char* jobz,const char* uplo,const int* n,
std::complex<double>* a,const int* lda,std::complex<double>* b,const int* ldb,
double* w,std::complex<double>* work,int* lwork,double* rwork,int* info);
void chegv_(const int* itype,const char* jobz,const char* uplo,const int* n,
std::complex<float>* a,const int* lda,std::complex<float>* b,const int* ldb,
float* w,std::complex<float>* work,int* lwork,float* rwork,int* info);
void dsygv_(const int* itype, const char* jobz,const char* uplo, const int* n,
double* a,const int* lda,double* b,const int* ldb,
double* w,double* work,int* lwork,int* info);

// solve the eigenproblem Ax=ex, where A is Hermitian and complex couble
// zheev_ returns all eigenvalues while zheevx_ returns selected ones
void zheev_(const char* jobz,const char* uplo,const int* n,std::complex<double> *a,
Expand All @@ -86,18 +102,6 @@ extern "C"
void cheev_(const char* jobz,const char* uplo,const int* n,std::complex<float> *a,
const int* lda,float* w,std::complex<float >* work,const int* lwork,
float* rwork,int* info);

// solve the generalized eigenproblem Ax=eBx, where A is Symmetric and real couble
// dsygv_ returns all eigenvalues while dsygvx_ returns selected ones
void dsygv_(const int* itype, const char* jobz,const char* uplo, const int* n,
double* a,const int* lda,double* b,const int* ldb,
double* w,double* work,int* lwork,int* info);
void dsygvx_(const int* itype, const char* jobz, const char* range, const char* uplo,
const int* n, double* A, const int* lda, double* B, const int* ldb,
const double* vl, const double* vu, const int* il, const int* iu,
const double* abstol, int* m, double* w, double* Z, const int* ldz,
double* work, int* lwork, int*iwork, int* ifail, int* info);
// solve the eigenproblem Ax=ex, where A is Symmetric and real double
void dsyev_(const char* jobz,const char* uplo,const int* n,double *a,
const int* lda,double* w,double* work,const int* lwork, int* info);

Expand Down Expand Up @@ -314,23 +318,19 @@ class LapackConnector
}

public:
// wrap function of fortran lapack routine zhegvd.
static inline
void zhegvd(const int itype, const char jobz, const char uplo, const int n,
std::complex<double>* a, const int lda,
const std::complex<double>* b, const int ldb, double* w,
std::complex<double>* work, int lwork, double* rwork, int lrwork,
int* iwork, int liwork, int& info)
int ilaenv( int ispec, const char *name,const char *opts,const int n1,const int n2,
const int n3,const int n4)
{
zhegvd_(&itype, &jobz, &uplo, &n,
a, &lda, b, &ldb, w,
work, &lwork, rwork, &lrwork,
iwork, &liwork, &info);
const int nb = ilaenv_(&ispec, name, opts, &n1, &n2, &n3, &n4);
return nb;
}



// wrap function of fortran lapack routine zhegvd. (pointer version)
static inline
void xhegvd(const int itype, const char jobz, const char uplo, const int n,
void xhegvd(const int itype, const char jobz, const char uplo, const int n,
double* a, const int lda,
const double* b, const int ldb, double* w,
double* work, int lwork, double* rwork, int lrwork,
Expand Down Expand Up @@ -373,23 +373,9 @@ class LapackConnector
iwork, &liwork, &info);
}

// wrap function of fortran lapack routine zheevx.
static inline
void zheevx( const int itype, const char jobz, const char range, const char uplo, const int n,
std::complex<double>* a, const int lda,
const double vl, const double vu, const int il, const int iu, const double abstol,
const int m, double* w, std::complex<double>* z, const int ldz,
std::complex<double>* work, const int lwork, double* rwork, int* iwork, int* ifail, int& info)
{
zheevx_(&jobz, &range, &uplo, &n,
a, &lda, &vl, &vu, &il, &iu,
&abstol, &m, w, z, &ldz,
work, &lwork, rwork, iwork, ifail, &info);
}

// wrap function of fortran lapack routine dsyevx.
static inline
void xheevx(const int itype, const char jobz, const char range, const char uplo, const int n,
void xheevx(const int itype, const char jobz, const char range, const char uplo, const int n,
double* a, const int lda,
const double vl, const double vu, const int il, const int iu, const double abstol,
const int m, double* w, double* z, const int ldz,
Expand Down Expand Up @@ -428,6 +414,98 @@ class LapackConnector
&abstol, &m, w, z, &ldz,
work, &lwork, rwork, iwork, ifail, &info);
}

// wrap function of fortran lapack routine xhegvx ( pointer version ).
static inline
void xhegvx( const int itype, const char jobz, const char range, const char uplo,
const int n, std::complex<float>* a, const int lda, std::complex<float>* b,
const int ldb, const float vl, const float vu, const int il, const int iu,
const float abstol, const int m, float* w, std::complex<float>* z, const int ldz,
std::complex<float>* work, const int lwork, float* rwork, int* iwork,
int* ifail, int& info)
{
chegvx_(&itype, &jobz, &range, &uplo, &n, a, &lda, b, &ldb, &vl,
&vu, &il,&iu, &abstol, &m, w, z, &ldz, work, &lwork, rwork, iwork, ifail, &info);
}

// wrap function of fortran lapack routine xhegvx ( pointer version ).
static inline
void xhegvx( const int itype, const char jobz, const char range, const char uplo,
const int n, std::complex<double>* a, const int lda, std::complex<double>* b,
const int ldb, const double vl, const double vu, const int il, const int iu,
const double abstol, const int m, double* w, std::complex<double>* z, const int ldz,
std::complex<double>* work, const int lwork, double* rwork, int* iwork,
int* ifail, int& info)
{
zhegvx_(&itype, &jobz, &range, &uplo, &n, a, &lda, b, &ldb, &vl,
&vu, &il,&iu, &abstol, &m, w, z, &ldz, work, &lwork, rwork, iwork, ifail, &info);
}
// wrap function of fortran lapack routine xhegvx ( pointer version ).
static inline
void xhegvx( const int itype, const char jobz, const char range, const char uplo,
const int n, double* a, const int lda, double* b,
const int ldb, const double vl, const double vu, const int il, const int iu,
const double abstol, const int m, double* w, double* z, const int ldz,
double* work, const int lwork, double* rwork, int* iwork,
int* ifail, int& info)
{
// dsygvx_(&itype, &jobz, &range, &uplo, &n, a, &lda, b, &ldb, &vl,
// &vu, &il,&iu, &abstol, &m, w, z, &ldz, work, &lwork, rwork, iwork, ifail, &info);
}


// wrap function of fortran lapack routine xhegvx ( pointer version ).
static inline
void xhegv( const int itype, const char jobz, const char uplo,
const int n,
double* a, const int lda,
double* b, const int ldb,
double* w,
double* work, int lwork,
double* rwork, int& info)
{
// TODO
}

// wrap function of fortran lapack routine xhegvx ( pointer version ).
static inline
void xhegv( const int itype, const char jobz, const char uplo,
const int n,
std::complex<float>* a, const int lda,
std::complex<float>* b, const int ldb,
float* w,
std::complex<float>* work, int lwork,
float* rwork, int& info)
{
// TODO
}
// wrap function of fortran lapack routine xhegvx ( pointer version ).
static inline
void xhegv( const int itype, const char jobz, const char uplo,
const int n,
std::complex<double>* a, const int lda,
std::complex<double>* b, const int ldb,
double* w,
std::complex<double>* work, int lwork,
double* rwork, int& info)
{
zhegv_(&itype, &jobz, &uplo, &n, a, &lda, b, &ldb, w, work, &lwork, rwork, &info);
}


// wrap function of fortran lapack routine zhegvd.
static inline
void zhegvd(const int itype, const char jobz, const char uplo, const int n,
std::complex<double>* a, const int lda,
const std::complex<double>* b, const int ldb, double* w,
std::complex<double>* work, int lwork, double* rwork, int lrwork,
int* iwork, int liwork, int& info)
{
zhegvd_(&itype, &jobz, &uplo, &n,
a, &lda, b, &ldb, w,
work, &lwork, rwork, &lrwork,
iwork, &liwork, &info);
}

// wrap function of fortran lapack routine zhegv ( ModuleBase::ComplexMatrix version ).
static inline
Expand Down Expand Up @@ -543,30 +621,6 @@ class LapackConnector
delete[] zux;
}

// wrap function of fortran lapack routine xhegvx ( pointer version ).
static inline
void xhegvx( const int itype, const char jobz, const char range, const char uplo,
const int n, const std::complex<float>* a, const int lda, const std::complex<float>* b,
const int ldb, const float vl, const float vu, const int il, const int iu,
const float abstol, const int m, float* w, std::complex<float>* z, const int ldz,
std::complex<float>* work, const int lwork, float* rwork, int* iwork,
int* ifail, int& info, int nbase_x)
{
chegvx(itype, jobz, range, uplo, n, a, lda, b, ldb, vl, vu, il, iu, abstol, m, w, z, ldz, work, lwork, rwork, iwork, ifail, info, nbase_x);
}

// wrap function of fortran lapack routine xhegvx ( pointer version ).
static inline
void xhegvx( const int itype, const char jobz, const char range, const char uplo,
const int n, const std::complex<double>* a, const int lda, const std::complex<double>* b,
const int ldb, const double vl, const double vu, const int il, const int iu,
const double abstol, const int m, double* w, std::complex<double>* z, const int ldz,
std::complex<double>* work, const int lwork, double* rwork, int* iwork,
int* ifail, int& info, int nbase_x)
{
zhegvx(itype, jobz, range, uplo, n, a, lda, b, ldb, vl, vu, il, iu, abstol, m, w, z, ldz, work, lwork, rwork, iwork, ifail, info, nbase_x);
}

// calculate the eigenvalues and eigenfunctions of a real symmetric matrix.
static inline
void dsygv( const int itype,const char jobz,const char uplo,const int n,ModuleBase::matrix& a,
Expand Down
2 changes: 1 addition & 1 deletion source/module_basis/module_ao/ORB_control.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ void ORB_control::setup_2d_division(std::ofstream& ofs_running,

// determine whether 2d-division or not according to ks_solver
bool div_2d;
if (ks_solver == "lapack" || ks_solver == "cg" || ks_solver == "dav") div_2d = false;
if (ks_solver == "lapack" || ks_solver == "cg" || ks_solver == "dav" || ks_solver == "dav_subspace") div_2d = false;
#ifdef __MPI
else if (ks_solver == "genelpa" || ks_solver == "scalapack_gvx" || ks_solver == "cusolver" || ks_solver == "cg_in_lcao") div_2d = true;
#endif
Expand Down
4 changes: 4 additions & 0 deletions source/module_elecstate/elecstate_print.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,10 @@ void ElecState::print_etot(const bool converged,
{
label = "DA";
}
else if (ks_solver_type == "dav_subspace")
{
label = "DS";
}
else if (ks_solver_type == "scalapack_gvx")
{
label = "GV";
Expand Down
6 changes: 3 additions & 3 deletions source/module_hamilt_pw/hamilt_pwdft/wavefunc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,7 @@ void diago_PAO_in_pw_k2(const psi::DEVICE_GPU *ctx,
//GlobalC::hm.diagH_subspace(ik ,starting_nw, nbands, wfcatom, wfcatom, etatom.data());
}
}
else if(GlobalV::KS_SOLVER=="dav")
else if (GlobalV::KS_SOLVER == "dav" || GlobalV::KS_SOLVER == "dav_subspace")
{
assert(nbands <= wfcatom.nr);
// replace by haozhihan 2022-11-23
Expand Down Expand Up @@ -685,8 +685,8 @@ void diago_PAO_in_pw_k2(const psi::DEVICE_GPU *ctx,
//GlobalC::hm.diagH_subspace(ik ,starting_nw, nbands, wfcatom, wfcatom, etatom.data());
}
}
else if(GlobalV::KS_SOLVER=="dav")
{
else if (GlobalV::KS_SOLVER == "dav" || GlobalV::KS_SOLVER == "dav_subspace")
{
assert(nbands <= wfcatom.nr);
// replace by haozhihan 2022-11-23
hsolver::matrixSetToAnother<std::complex<double>, psi::DEVICE_GPU>()(
Expand Down
1 change: 1 addition & 0 deletions source/module_hsolver/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ list(APPEND objects
diagh_consts.cpp
diago_cg.cpp
diago_david.cpp
diago_dav_subspace.cpp
diago_bpcg.cpp
hsolver_pw.cpp
hsolver_pw_sdft.cpp
Expand Down

0 comments on commit 21d40ce

Please sign in to comment.