Skip to content

Commit

Permalink
Refactor: remove some global variables in read_dm and read_wfc_nao (#…
Browse files Browse the repository at this point in the history
…3794)

* remove some global variables in read_dm

* Refactor: remove some global variables in read_wfc_nao.cpp

* Fix: add a parameter  in distri_wfc_nao
  • Loading branch information
dzzz2001 committed Mar 26, 2024
1 parent 869929a commit 3fc4ac2
Show file tree
Hide file tree
Showing 8 changed files with 101 additions and 67 deletions.
3 changes: 3 additions & 0 deletions source/module_esolver/esolver_ks_lcao_elec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,9 @@ void ESolver_KS_LCAO<TK, TR>::beforesolver(const int istep)
this->GridT.nnrg,
this->GridT.trace_lo,
#endif
GlobalV::GAMMA_ONLY_LOCAL,
GlobalV::NLOCAL,
GlobalV::NSPIN,
is,
ssd.str(),
this->LOC.DM,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ void Local_Orbital_wfc::gamma_file(psi::Psi<double>* psid, elecstate::ElecState*

for (int is = 0; is < GlobalV::NSPIN; ++is)
{
this->error = ModuleIO::read_wfc_nao(ctot, is, this->ParaV, psid, pelec);
this->error = ModuleIO::read_wfc_nao(ctot, is, GlobalV::GAMMA_ONLY_LOCAL, GlobalV::NB2D, GlobalV::NBANDS,
GlobalV::NLOCAL, GlobalV::global_readin_dir, this->ParaV, psid, pelec);
#ifdef __MPI
Parallel_Common::bcast_int(this->error);
#endif
Expand Down Expand Up @@ -164,7 +165,8 @@ void Local_Orbital_wfc::allocate_k(const int& lgd,
for (int ik = 0; ik < nkstot; ++ik)
{
std::complex<double>** ctot;
this->error = ModuleIO::read_wfc_nao_complex(ctot, ik, kvec_c[ik], this->ParaV, psi, pelec);
this->error = ModuleIO::read_wfc_nao_complex(ctot, ik, GlobalV::NB2D, GlobalV::NBANDS, GlobalV::NLOCAL,
GlobalV::global_readin_dir, kvec_c[ik], this->ParaV, psi, pelec);
#ifdef __MPI
Parallel_Common::bcast_int(this->error);
#endif
Expand Down
3 changes: 3 additions & 0 deletions source/module_io/dm_io.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ void read_dm(
const int nnrg,
const int* trace_lo,
#endif
const bool gamma_only_local,
const int nlocal,
const int nspin,
const int &is,
const std::string &fn,
double*** DM,
Expand Down
29 changes: 16 additions & 13 deletions source/module_io/read_dm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ void ModuleIO::read_dm(
const int nnrg,
const int* trace_lo,
#endif
const bool gamma_only_local,
const int nlocal,
const int nspin,
const int &is,
const std::string &fn,
double*** DM,
Expand Down Expand Up @@ -77,22 +80,22 @@ void ModuleIO::read_dm(
}
}

ModuleBase::CHECK_INT(ifs, GlobalV::NSPIN);
ModuleBase::CHECK_INT(ifs, nspin);
ModuleBase::GlobalFunc::READ_VALUE(ifs, ef);
ModuleBase::CHECK_INT(ifs, GlobalV::NLOCAL);
ModuleBase::CHECK_INT(ifs, GlobalV::NLOCAL);
ModuleBase::CHECK_INT(ifs, nlocal);
ModuleBase::CHECK_INT(ifs, nlocal);
}// If file exist, read in data.
} // Finish reading the first part of density matrix.


#ifndef __MPI
GlobalV::ofs_running << " Read SPIN = " << is+1 << " density matrix now." << std::endl;

if(GlobalV::GAMMA_ONLY_LOCAL)
if(gamma_only_local)
{
for(int i=0; i<GlobalV::NLOCAL; ++i)
for(int i=0; i<nlocal; ++i)
{
for(int j=0; j<GlobalV::NLOCAL; ++j)
for(int j=0; j<nlocal; ++j)
{
ifs >> DM[is][i][j];
}
Expand Down Expand Up @@ -125,27 +128,27 @@ void ModuleIO::read_dm(

Parallel_Common::bcast_double(ef);

if(GlobalV::GAMMA_ONLY_LOCAL)
if(gamma_only_local)
{

double *tmp = new double[GlobalV::NLOCAL];
for(int i=0; i<GlobalV::NLOCAL; ++i)
double *tmp = new double[nlocal];
for(int i=0; i<nlocal; ++i)
{
//GlobalV::ofs_running << " i=" << i << std::endl;
ModuleBase::GlobalFunc::ZEROS(tmp, GlobalV::NLOCAL);
ModuleBase::GlobalFunc::ZEROS(tmp, nlocal);
if(GlobalV::MY_RANK==0)
{
for(int j=0; j<GlobalV::NLOCAL; ++j)
for(int j=0; j<nlocal; ++j)
{
ifs >> tmp[j];
}
}
Parallel_Common::bcast_double(tmp, GlobalV::NLOCAL);
Parallel_Common::bcast_double(tmp, nlocal);

const int mu = trace_lo[i];
if(mu >= 0)
{
for(int j=0; j<GlobalV::NLOCAL; ++j)
for(int j=0; j<nlocal; ++j)
{
const int nu = trace_lo[j];
if(nu >= 0)
Expand Down

0 comments on commit 3fc4ac2

Please sign in to comment.