Skip to content

Commit

Permalink
Refactor&Fix: remove normalization of numerical atomic orbitals in ps…
Browse files Browse the repository at this point in the history
…i_initializer (#3716)

* turn off normalization of pw-represented numerical atomic orbitals in towannier90 interface

* refactor not complete yet

* make code brief and change unittest correspondingly

* change to use STL containers with functions as many as possible to avoid memory leak

* add enough annotations

* turn off normalize

---------

Co-authored-by: wqzhou <33364058+WHUweiqingzhou@users.noreply.github.com>
  • Loading branch information
kirk0830 and WHUweiqingzhou committed Mar 26, 2024
1 parent b447558 commit 7dccf77
Show file tree
Hide file tree
Showing 16 changed files with 1,157 additions and 1,397 deletions.
338 changes: 139 additions & 199 deletions source/module_esolver/esolver_ks_pw.cpp

Large diffs are not rendered by default.

8 changes: 7 additions & 1 deletion source/module_esolver/esolver_ks_pw.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "./esolver_ks.h"
#include "module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.h"
#include "module_psi/psi_initializer.h"
#include <memory>
#include <module_base/macros.h>

// #include "Basis_PW.h"
Expand Down Expand Up @@ -100,7 +101,12 @@ namespace ModuleESolver
protected:
psi::Psi<std::complex<double>, psi::DEVICE_CPU>* psi = nullptr; //hide the psi in ESolver_KS for tmp use
private:
psi_initializer<T, Device>* psi_init = nullptr;
// psi_initializer<T, Device>* psi_init = nullptr;
// change to use smart pointer to manage the memory, and avoid memory leak
// while the std::make_unique() is not supported till C++14,
// so use the new and std::unique_ptr to manage the memory, but this makes new-delete not symmetric
std::unique_ptr<psi_initializer<T, Device>> psi_init;

Device * ctx = {};
psi::AbacusDevice_t device = {};
psi::Psi<T, Device>* kspw_psi = nullptr;
Expand Down
17 changes: 8 additions & 9 deletions source/module_io/to_wannier90_lcao_in_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,13 @@ void toWannier90_LCAO_IN_PW::calculate(

Structure_Factor* sf_ptr = const_cast<Structure_Factor*>(&sf);
ModulePW::PW_Basis_K* wfcpw_ptr = const_cast<ModulePW::PW_Basis_K*>(wfcpw);
this->psi_init_ = new psi_initializer_nao<std::complex<double>, psi::DEVICE_CPU>();
#ifdef __MPI
this->psi_init_ = new psi_initializer_nao<std::complex<double>, psi::DEVICE_CPU>(
sf_ptr, wfcpw_ptr, &(GlobalC::ucell), &(GlobalC::Pkpoints));
this->psi_init_->initialize(sf_ptr, wfcpw_ptr, &(GlobalC::ucell), &(GlobalC::Pkpoints), 1, nullptr, GlobalV::MY_RANK);
#else
this->psi_init_ = new psi_initializer_nao<std::complex<double>, psi::DEVICE_CPU>(
sf_ptr, wfcpw_ptr, &(GlobalC::ucell));
this->psi_init_->initialize(sf_ptr, wfcpw_ptr, &(GlobalC::ucell), 1, nullptr);
#endif
this->psi_init_->set_orbital_files(GlobalC::ucell.orbital_fn);
this->psi_init_->initialize_only_once();
this->psi_init_->cal_ovlp_flzjlq();
this->psi_init_->tabulate();
this->psi_init_->allocate(true);
read_nnkp(kv);

Expand Down Expand Up @@ -218,14 +215,16 @@ void toWannier90_LCAO_IN_PW::nao_G_expansion(
)
{
int npwx = wfcpw->npwk_max;
psi::Psi<std::complex<double>>* psig = this->psi_init_->cal_psig(ik);
this->psi_init_->proj_ao_onkG(ik);
std::weak_ptr<psi::Psi<std::complex<double>>> psig = this->psi_init_->share_psig();
if(psig.expired()) ModuleBase::WARNING_QUIT("toWannier90_LCAO_IN_PW::nao_G_expansion", "psig is expired");
int nbands = GlobalV::NLOCAL;
int nbasis = npwx*GlobalV::NPOL;
for (int ib = 0; ib < nbands; ib++)
{
for (int ig = 0; ig < nbasis; ig++)
{
psi(ib, ig) = psig[0](ik, ib, ig);
psi(ib, ig) = psig.lock().get()[0](ik, ib, ig);
}
}
}
Expand Down
168 changes: 61 additions & 107 deletions source/module_psi/psi_initializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,145 +6,103 @@
// three global variables definition
#include "module_base/global_variable.h"


template<typename T, typename Device>
#ifdef __MPI
psi_initializer<T, Device>::psi_initializer(Structure_Factor* sf_in, ModulePW::PW_Basis_K* pw_wfc_in, UnitCell* p_ucell_in, Parallel_Kpoints* p_parakpts_in, int random_seed_in)
: sf(sf_in), pw_wfc(pw_wfc_in), p_ucell(p_ucell_in), p_parakpts(p_parakpts_in), random_seed(random_seed_in)
#else
psi_initializer<T, Device>::psi_initializer(Structure_Factor* sf_in, ModulePW::PW_Basis_K* pw_wfc_in, UnitCell* p_ucell_in, int random_seed_in)
: sf(sf_in), pw_wfc(pw_wfc_in), p_ucell(p_ucell_in), random_seed(random_seed_in)
#endif
{
if(this->p_ucell == nullptr) ModuleBase::WARNING_QUIT("psi_initializer", "interface to UnitCell is not valid, quit!");
this->ixy2is = new int[this->pw_wfc->fftnxy];
this->pw_wfc->getfftixy2is(this->ixy2is);
}

template<typename T, typename Device>
psi_initializer<T, Device>::~psi_initializer()
{
delete[] this->ixy2is;
if (this->psig != nullptr)
{
delete this->psig;
this->psig = nullptr;
}
}

template<typename T, typename Device>
psi::Psi<std::complex<double>>* psi_initializer<T, Device>::allocate(bool only_psig)
{
ModuleBase::timer::tick("psi_initializer", "allocate");
/*
WARNING: when basis_type = "pw", the variable GlobalV::NLOCAL will also be set, in this case, it is set to
9 = 1 + 3 + 5, which is the maximal number of orbitals spd, I don't think it is reasonable
The way of calculating this->p_ucell->natomwfc is, for each atom, read pswfc and for s, it is 1, for p, it is 3
The way of calculating this->p_ucell_->natomwfc is, for each atom, read pswfc and for s, it is 1, for p, it is 3
, then multiplied by the number of atoms, and then add them together.
*/

if (this->psig != nullptr)
{
delete this->psig;
this->psig = nullptr;
}
int prefactor = 1;
int nbands_actual = 0;
if(this->method == "random")
if(this->method_ == "random")
{
nbands_actual = GlobalV::NBANDS;
this->nbands_complem = 0;
this->nbands_complem_ = 0;
}
else
{
if(this->method.substr(0, 6) == "atomic")
if(this->method_.substr(0, 6) == "atomic")
{
if(this->p_ucell->natomwfc >= GlobalV::NBANDS)
if(this->p_ucell_->natomwfc >= GlobalV::NBANDS)
{
nbands_actual = this->p_ucell->natomwfc;
this->nbands_complem = 0;
nbands_actual = this->p_ucell_->natomwfc;
this->nbands_complem_ = 0;
}
else
{
nbands_actual = GlobalV::NBANDS;
this->nbands_complem = GlobalV::NBANDS - this->p_ucell->natomwfc;
this->nbands_complem_ = GlobalV::NBANDS - this->p_ucell_->natomwfc;
}
}
else if(this->method.substr(0, 3) == "nao")
else if(this->method_.substr(0, 3) == "nao")
{
/*
previously GlobalV::NLOCAL is used here, however it is wrong. GlobalV::NLOCAL is fixed to 9*nat.
*/
int nbands_local = 0;
for(int it = 0; it < this->p_ucell->ntype; it++)
for(int it = 0; it < this->p_ucell_->ntype; it++)
{
for(int ia = 0; ia < this->p_ucell->atoms[it].na; ia++)
for(int ia = 0; ia < this->p_ucell_->atoms[it].na; ia++)
{
/* FOR EVERY ATOM */
for(int l = 0; l < this->p_ucell->atoms[it].nwl + 1; l++)
for(int l = 0; l < this->p_ucell_->atoms[it].nwl + 1; l++)
{
/* EVERY ZETA FOR (2l+1) ORBS */
/*
non-rotate basis, nbands_local*=2 for GlobalV::NPOL = 2 is enough
*/
//nbands_local += this->p_ucell->atoms[it].l_nchi[l]*(2*l+1) * GlobalV::NPOL;
//nbands_local += this->p_ucell_->atoms[it].l_nchi[l]*(2*l+1) * GlobalV::NPOL;
/*
rotate basis, nbands_local*=4 for p, d, f,... orbitals, and nbands_local*=2 for s orbitals
risky when NSPIN = 4, problematic psi value, needed to be checked
*/
if(l == 0)
{
nbands_local += this->p_ucell->atoms[it].l_nchi[l] * GlobalV::NPOL;
}
else
{
nbands_local += this->p_ucell->atoms[it].l_nchi[l]*(2*l+1) * GlobalV::NPOL;
}

if(l == 0) nbands_local += this->p_ucell_->atoms[it].l_nchi[l] * GlobalV::NPOL;
else nbands_local += this->p_ucell_->atoms[it].l_nchi[l]*(2*l+1) * GlobalV::NPOL;
}
}
}
if(nbands_local >= GlobalV::NBANDS)
{
nbands_actual = nbands_local;
this->nbands_complem = 0;
this->nbands_complem_ = 0;
}
else
{
nbands_actual = GlobalV::NBANDS;
this->nbands_complem = GlobalV::NBANDS - nbands_local;
this->nbands_complem_ = GlobalV::NBANDS - nbands_local;
}
}
}
int nkpts_actual = (GlobalV::CALCULATION == "nscf" && this->mem_saver == 1)?
1 : this->pw_wfc->nks;
int nbasis_actual = this->pw_wfc->npwk_max * GlobalV::NPOL;
int nkpts_actual = (GlobalV::CALCULATION == "nscf" && this->mem_saver_ == 1)? 1 : this->pw_wfc_->nks;
int nbasis_actual = this->pw_wfc_->npwk_max * GlobalV::NPOL;
psi::Psi<std::complex<double>>* psi_out = nullptr;
if(!only_psig)
{
psi_out = new psi::Psi<std::complex<double>>(
nkpts_actual,
GlobalV::NBANDS, // because no matter what, the wavefunction finally needed has GlobalV::NBANDS bands
nbasis_actual,
this->pw_wfc->npwk);
psi_out = new psi::Psi<std::complex<double>>(nkpts_actual,
GlobalV::NBANDS, // because no matter what, the wavefunction finally needed has GlobalV::NBANDS bands
nbasis_actual,
this->pw_wfc_->npwk);
/*
WARNING: this will cause DIRECT MEMORY LEAK, psi is not properly deallocated
*/
const size_t memory_cost_psi =
nkpts_actual*
GlobalV::NBANDS * this->pw_wfc->npwk_max * GlobalV::NPOL*
GlobalV::NBANDS * this->pw_wfc_->npwk_max * GlobalV::NPOL*
sizeof(std::complex<double>);
std::cout << " MEMORY FOR PSI PER PROCESSOR (MB) : " << double(memory_cost_psi)/1024.0/1024.0 << std::endl;
ModuleBase::Memory::record("Psi_PW", memory_cost_psi);
}
this->psig = new psi::Psi<T, Device>(
nkpts_actual,
nbands_actual,
nbasis_actual,
this->pw_wfc->npwk);
this->psig_ = std::make_shared<psi::Psi<T, Device>>(nkpts_actual,
nbands_actual,
nbasis_actual,
this->pw_wfc_->npwk);
const size_t memory_cost_psig =
nkpts_actual*
nbands_actual * this->pw_wfc->npwk_max * GlobalV::NPOL*
nbands_actual * this->pw_wfc_->npwk_max * GlobalV::NPOL*
sizeof(T);
std::cout << " MEMORY FOR AUXILLARY PSI PER PROCESSOR (MB) : " << double(memory_cost_psig)/1024.0/1024.0 << std::endl;

Expand All @@ -155,11 +113,11 @@ psi::Psi<std::complex<double>>* psi_initializer<T, Device>::allocate(bool only_p
<< "nkpts_actual = " << nkpts_actual << "\n"
<< "GlobalV::NBANDS = " << GlobalV::NBANDS << "\n"
<< "nbands_actual = " << nbands_actual << "\n"
<< "nbands_complem = " << this->nbands_complem << "\n"
<< "nbands_complem = " << this->nbands_complem_ << "\n"
<< "nbasis_actual = " << nbasis_actual << "\n"
<< "npwk_max = " << this->pw_wfc->npwk_max << "\n"
<< "npwk_max = " << this->pw_wfc_->npwk_max << "\n"
<< "npol = " << GlobalV::NPOL << "\n";
ModuleBase::Memory::record("PsiG_PW", memory_cost_psig);
ModuleBase::Memory::record("psigPW", memory_cost_psig);
ModuleBase::timer::tick("psi_initializer", "allocate");
return psi_out;
}
Expand All @@ -169,82 +127,78 @@ void psi_initializer<T, Device>::random_t(T* psi, const int iw_start, const int
{
ModuleBase::timer::tick("psi_initializer", "random_t");
assert(iw_start >= 0);
const int ng = this->pw_wfc->npwk[ik];
const int ng = this->pw_wfc_->npwk[ik];
#ifdef __MPI
if (this->random_seed > 0) // qianrui add 2021-8-13
if (this->random_seed_ > 0) // qianrui add 2021-8-13
{
srand(unsigned(this->random_seed + this->p_parakpts->startk_pool[GlobalV::MY_POOL] + ik));
const int nxy = this->pw_wfc->fftnxy;
const int nz = this->pw_wfc->nz;
const int nstnz = this->pw_wfc->nst*nz;
srand(unsigned(this->random_seed_ + this->p_parakpts_->startk_pool[GlobalV::MY_POOL] + ik));
const int nxy = this->pw_wfc_->fftnxy;
const int nz = this->pw_wfc_->nz;
const int nstnz = this->pw_wfc_->nst*nz;

Real *stickrr = new Real[nz];
Real *stickarg = new Real[nz];
Real *tmprr = new Real[nstnz];
Real *tmparg = new Real[nstnz];
std::vector<Real> stickrr(nz);
std::vector<Real> stickarg(nz);
std::vector<Real> tmprr(nstnz);
std::vector<Real> tmparg(nstnz);
for (int iw = iw_start; iw < iw_end; iw++)
{
// get the starting memory address of iw band
T* psi_slice = &(psi[iw * this->pw_wfc->npwk_max * GlobalV::NPOL]);
T* psi_slice = &(psi[iw * this->pw_wfc_->npwk_max * GlobalV::NPOL]);
int startig = 0;
for(int ipol = 0 ; ipol < GlobalV::NPOL ; ++ipol)
{

for(int ir=0; ir < nxy; ir++)
{
if(this->pw_wfc->fftixy2ip[ir] < 0) continue;
if(this->pw_wfc_->fftixy2ip[ir] < 0) continue;
if(GlobalV::RANK_IN_POOL==0)
{
for(int iz=0; iz<nz; iz++)
{
stickrr[ iz ] = std::rand()/Real(RAND_MAX);
stickarg[ iz ] = std::rand()/Real(RAND_MAX);
stickrr[iz] = std::rand()/Real(RAND_MAX);
stickarg[iz] = std::rand()/Real(RAND_MAX);
}
}
stick_to_pool(stickrr, ir, tmprr);
stick_to_pool(stickarg, ir, tmparg);
stick_to_pool(stickrr.data(), ir, tmprr.data());
stick_to_pool(stickarg.data(), ir, tmparg.data());
}

for (int ig = 0;ig < ng;ig++)
{
const double rr = tmprr[this->pw_wfc->getigl2isz(ik,ig)];
const double arg= ModuleBase::TWO_PI * tmparg[this->pw_wfc->getigl2isz(ik,ig)];
const double gk2 = this->pw_wfc->getgk2(ik,ig);
const double rr = tmprr[this->pw_wfc_->getigl2isz(ik,ig)];
const double arg= ModuleBase::TWO_PI * tmparg[this->pw_wfc_->getigl2isz(ik,ig)];
const double gk2 = this->pw_wfc_->getgk2(ik,ig);
psi_slice[ig+startig] = this->template cast_to_T<T>(std::complex<double>(rr*cos(arg)/(gk2 + 1.0), rr*sin(arg)/(gk2 + 1.0)));
}
startig += this->pw_wfc->npwk_max;
startig += this->pw_wfc_->npwk_max;
}
}
delete[] stickrr;
delete[] stickarg;
delete[] tmprr;
delete[] tmparg;
}
else
{
#else // !__MPI
if (this->random_seed > 0) // qianrui add 2021-8-13
if (this->random_seed_ > 0) // qianrui add 2021-8-13
{
srand(unsigned(this->random_seed + ik));
srand(unsigned(this->random_seed_ + ik));
}
#endif
for (int iw = iw_start ;iw < iw_end; iw++)
{
T* psi_slice = &(psi[iw * this->pw_wfc->npwk_max * GlobalV::NPOL]);
T* psi_slice = &(psi[iw * this->pw_wfc_->npwk_max * GlobalV::NPOL]);
for (int ig = 0; ig < ng; ig++)
{
const double rr = std::rand()/double(RAND_MAX); //qianrui add RAND_MAX
const double arg= ModuleBase::TWO_PI * std::rand()/double(RAND_MAX);
const double gk2 = this->pw_wfc->getgk2(ik,ig);
const double gk2 = this->pw_wfc_->getgk2(ik,ig);
psi_slice[ig] = this->template cast_to_T<T>(std::complex<double>(rr*cos(arg)/(gk2 + 1.0), rr*sin(arg)/(gk2 + 1.0)));
}
if(GlobalV::NPOL==2)
{
for (int ig = this->pw_wfc->npwk_max; ig < this->pw_wfc->npwk_max + ng; ig++)
for (int ig = this->pw_wfc_->npwk_max; ig < this->pw_wfc_->npwk_max + ng; ig++)
{
const double rr = std::rand()/double(RAND_MAX);
const double arg= ModuleBase::TWO_PI * std::rand()/double(RAND_MAX);
const double gk2 = this->pw_wfc->getgk2(ik,ig-this->pw_wfc->npwk_max);
const double gk2 = this->pw_wfc_->getgk2(ik,ig-this->pw_wfc_->npwk_max);
psi_slice[ig] = this->template cast_to_T<T>(std::complex<double>(rr*cos(arg)/(gk2 + 1.0), rr*sin(arg)/(gk2 + 1.0)));
}
}
Expand All @@ -261,9 +215,9 @@ void psi_initializer<T, Device>::stick_to_pool(Real* stick, const int& ir, Real*
{
ModuleBase::timer::tick("psi_initializer", "stick_to_pool");
MPI_Status ierror;
const int is = this->ixy2is[ir];
const int ip = this->pw_wfc->fftixy2ip[ir];
const int nz = this->pw_wfc->nz;
const int is = this->ixy2is_[ir];
const int ip = this->pw_wfc_->fftixy2ip[ir];
const int nz = this->pw_wfc_->nz;

if(ip == 0 && GlobalV::RANK_IN_POOL ==0)
{
Expand Down

0 comments on commit 7dccf77

Please sign in to comment.