Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 61 additions & 7 deletions source/src_io/istate_envelope.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ IState_Envelope::~IState_Envelope()
{}


void IState_Envelope::begin(void)
void IState_Envelope::begin(Local_Orbital_wfc &lowf, Gint_Gamma &gg)
{
ModuleBase::TITLE("IState_Envelope","begin");

Expand Down Expand Up @@ -64,7 +64,30 @@ void IState_Envelope::begin(void)
}
}

for(int ib=0; ib<GlobalV::NBANDS; ib++)
//allocate grid wavefunction for gamma_only
std::vector<double**> wfc_gamma_grid(GlobalV::NSPIN);
for(int is=0; is<GlobalV::NSPIN; ++is)
{
wfc_gamma_grid[is] = new double* [GlobalV::NBANDS];
for (int ib = 0;ib < GlobalV::NBANDS; ++ib)
wfc_gamma_grid[is][ib] = new double[GlobalC::GridT.lgd];
}

const Parallel_Orbitals* pv = lowf.ParaV;

//calculate maxnloc for bcasting 2d-wfc
int nprocs, myid;
MPI_Comm_size(pv->comm_2D, &nprocs);
MPI_Comm_rank(pv->comm_2D, &myid);

long maxnloc; // maximum number of elements in local matrix
MPI_Reduce(&pv->nloc_wfc, &maxnloc, 1, MPI_LONG, MPI_MAX, 0, pv->comm_2D);
MPI_Bcast(&maxnloc, 1, MPI_LONG, 0, pv->comm_2D);
const int inc = 1;
int naroc[2]; // maximum number of row or column
double* work = new double[maxnloc]; // work/buffer matrix

for (int ib = 0; ib < GlobalV::NBANDS; ib++)
{
if(bands_picked[ib])
{
Expand All @@ -79,9 +102,33 @@ void IState_Envelope::begin(void)
// we need to fix this function in near future.
// -- mohan add 2021-02-09
//---------------------------------------------------------
ModuleBase::WARNING_QUIT("IState_Charge::idmatrix","need to update LOWF.WFC_GAMMA");

//GlobalC::UHM.GG.cal_env( GlobalC::LOWF.WFC_GAMMA[is][ib], GlobalC::CHR.rho[is] );
//ModuleBase::WARNING_QUIT("IState_Charge::idmatrix","need to update LOWF.WFC_GAMMA");

//convert 2d `wfc_gamma` to grid `wfc_gamma_grid`
int info;
for(int iprow=0; iprow<pv->dim0; ++iprow)
{
for(int ipcol=0; ipcol<pv->dim1; ++ipcol)
{
const int coord[2]={iprow, ipcol};
int src_rank;
MPI_Cart_rank(pv->comm_2D, coord, &src_rank);
if(myid==src_rank)
{
BlasConnector::copy(pv->nloc_wfc, lowf.wfc_gamma[is].c, inc, work, inc);
naroc[0]=pv->nrow;
naroc[1]=pv->ncol_bands;
}
info=MPI_Bcast(naroc, 2, MPI_INT, src_rank, pv->comm_2D);
info=MPI_Bcast(work, maxnloc, MPI_DOUBLE, src_rank, pv->comm_2D);

info=lowf.q2WFC(myid, naroc, pv->nb,
pv->dim0, pv->dim1, iprow, ipcol, pv->loc_size,
work, wfc_gamma_grid[is]);
}//loop ipcol
}//loop iprow

gg.cal_env( wfc_gamma_grid[is][ib], GlobalC::CHR.rho[is] );


GlobalC::CHR.save_rho_before_sum_band(); //xiaohui add 2014-12-09
Expand All @@ -94,8 +141,15 @@ void IState_Envelope::begin(void)
}
}

delete[] bands_picked;
return;
delete[] work;
delete[] bands_picked;
for(int is=0; is<GlobalV::NSPIN; ++is)
{
for (int ib = 0;ib < GlobalV::NBANDS; ++ib)
delete[] wfc_gamma_grid[is][ib];
delete[] wfc_gamma_grid[is];
}
return;
}


4 changes: 3 additions & 1 deletion source/src_io/istate_envelope.h
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
#ifndef ISTATE_ENVELOPE_H
#define ISTATE_ENVELOPE_H
#include "src_lcao/local_orbital_wfc.h"
#include "src_lcao/gint_gamma.h"

class IState_Envelope
{
public:
IState_Envelope();
~IState_Envelope();

void begin();
void begin(Local_Orbital_wfc &lowf, Gint_Gamma &gg);

private:
bool *bands_picked;
Expand Down
21 changes: 9 additions & 12 deletions source/src_io/wf_local.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,13 @@ inline int CTOT2q_c(
}

// be called in local_orbital_wfc::allocate_k
int WF_Local::read_lowf_complex(std::complex<double>** c, const int& ik,
int WF_Local::read_lowf_complex(std::complex<double>** ctot, const int& ik,
Local_Orbital_wfc &lowf)
{
ModuleBase::TITLE("WF_Local","read_lowf_complex");
ModuleBase::timer::tick("WF_Local","read_lowf_complex");

std::complex<double> **ctot;

lowf.wfc_k[ik].create(lowf.ParaV->ncol_bands, lowf.ParaV->nrow);
std::stringstream ss;
// read wave functions
// write is in ../src_pdiag/pdiag_basic.cpp
Expand Down Expand Up @@ -212,14 +211,12 @@ int WF_Local::read_lowf_complex(std::complex<double>** c, const int& ik,
return 0;
}

int WF_Local::read_lowf(double** c, const int& is,
int WF_Local::read_lowf(double** ctot, const int& is,
Local_Orbital_wfc &lowf)
{
ModuleBase::TITLE("WF_Local","read_lowf");
ModuleBase::timer::tick("WF_Local","read_lowf");

double **ctot;

ModuleBase::timer::tick("WF_Local", "read_lowf");

std::stringstream ss;
if(GlobalV::GAMMA_ONLY_LOCAL)
{
Expand Down Expand Up @@ -498,7 +495,7 @@ void WF_Local::distri_lowf_complex_new(std::complex<double>** ctot, const int& i
//1. alloc work array; set some parameters

long maxnloc; // maximum number of elements in local matrix
MPI_Reduce(&lowf.ParaV->nloc, &maxnloc, 1, MPI_LONG, MPI_MAX, 0, lowf.ParaV->comm_2D);
MPI_Reduce(&lowf.ParaV->nloc_wfc, &maxnloc, 1, MPI_LONG, MPI_MAX, 0, lowf.ParaV->comm_2D);
MPI_Bcast(&maxnloc, 1, MPI_LONG, 0, lowf.ParaV->comm_2D);
//reduce and bcast could be replaced by allreduce

Expand Down Expand Up @@ -533,7 +530,7 @@ void WF_Local::distri_lowf_complex_new(std::complex<double>** ctot, const int& i
if(myid==src_rank)
{
naroc[0]=lowf.ParaV->nrow;
naroc[1]=lowf.ParaV->ncol;
naroc[1]=lowf.ParaV->ncol_bands;
}
info=MPI_Bcast(naroc, 2, MPI_INT, src_rank, lowf.ParaV->comm_2D);

Expand All @@ -547,10 +544,10 @@ void WF_Local::distri_lowf_complex_new(std::complex<double>** ctot, const int& i
//}
//ofs_running << std::endl;
//2.3 copy from work to wfc_k
const int inc=1;
const int inc = 1;
if(myid==src_rank)
{
BlasConnector::copy(lowf.ParaV->nloc, work, inc, lowf.wfc_k.at(ik).c, inc);
BlasConnector::copy(lowf.ParaV->nloc_wfc, work, inc, lowf.wfc_k.at(ik).c, inc);
}
}//loop ipcol
}//loop iprow
Expand Down
4 changes: 2 additions & 2 deletions source/src_io/wf_local.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ namespace WF_Local
void distri_lowf_complex_new(std::complex<double>** ctot, const int& ik,
Local_Orbital_wfc &lowf);

int read_lowf(double** c, const int& is,
int read_lowf(double** ctot, const int& is,
Local_Orbital_wfc &lowf);

int read_lowf_complex(std::complex<double>** c, const int& ik,
int read_lowf_complex(std::complex<double>** ctot, const int& ik,
Local_Orbital_wfc &lowf);
}

Expand Down
48 changes: 0 additions & 48 deletions source/src_lcao/DM_k.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,58 +49,10 @@ void Local_Orbital_Charge::allocate_DM_k(void)

// Peize Lin test 2019-01-16
this->init_dm_2d();
if(GlobalC::wf.start_wfc=="file")
{
this->kpt_file(GlobalC::GridT, *this->LOWF);
}

return;
}

void Local_Orbital_Charge::kpt_file(const Grid_Technique& gt,
Local_Orbital_wfc &lowf)
{
ModuleBase::TITLE("Local_Orbital_Charge","kpt_file");

int error;
std::cout << " Read in wave functions files: " << GlobalC::kv.nkstot << std::endl;

std::complex<double> **ctot;

for(int ik=0; ik<GlobalC::kv.nkstot; ++ik)
{

lowf.wfc_k[ik].create(this->ParaV->ncol_bands, this->ParaV->nrow);
lowf.wfc_k[ik].zero_out();

GlobalV::ofs_running << " Read in wave functions " << ik + 1 << std::endl;
error = WF_Local::read_lowf_complex( ctot , ik, lowf);

#ifdef __MPI
Parallel_Common::bcast_int(error);
#endif
GlobalV::ofs_running << " Error=" << error << std::endl;
if(error==1)
{
ModuleBase::WARNING_QUIT("Local_Orbital_wfc","Can't find the wave function file: LOWF.dat");
}
else if(error==2)
{
ModuleBase::WARNING_QUIT("Local_Orbital_wfc","In wave function file, band number doesn't match");
}
else if(error==3)
{
ModuleBase::WARNING_QUIT("Local_Orbital_wfc","In wave function file, nlocal doesn't match");
}
else if(error==4)
{
ModuleBase::WARNING_QUIT("Local_Orbital_wfc","In k-dependent wave function file, k point is not correct");
}

}//loop ispin
}


#include "record_adj.h"
inline void cal_DM_ATOM(
const Grid_Technique &gt,
Expand Down
2 changes: 1 addition & 1 deletion source/src_lcao/LOOP_elec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ void LOOP_elec::solver(const int& istep,
else if (GlobalV::CALCULATION=="ienvelope")
{
IState_Envelope IEP;
IEP.begin();
IEP.begin(lowf, this->UHM->GG);
}
else
{
Expand Down
6 changes: 3 additions & 3 deletions source/src_lcao/local_orbital_charge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,9 @@ void Local_Orbital_Charge::allocate_dm_wfc(const Grid_Technique& gt,
Local_Orbital_wfc &lowf)
{
ModuleBase::TITLE("Local_Orbital_Charge", "allocate_dm_wfc");
if(GlobalV::GAMMA_ONLY_LOCAL)

this->LOWF = &lowf;
if (GlobalV::GAMMA_ONLY_LOCAL)
{
// here we reset the density matrix dimension.
this->allocate_gamma(gt);
Expand All @@ -91,8 +93,6 @@ void Local_Orbital_Charge::allocate_dm_wfc(const Grid_Technique& gt,
lowf.allocate_k(gt, lowf);
this->allocate_DM_k();
}

this->LOWF = &lowf;

return;
}
Expand Down
3 changes: 0 additions & 3 deletions source/src_lcao/local_orbital_charge.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,6 @@ class Local_Orbital_Charge
// in DM_k.cpp
//-----------------
void allocate_DM_k(void);

void kpt_file(const Grid_Technique& gt,
Local_Orbital_wfc &lowf);

// liaochen modify on 2010-3-23
// change its state from private to public
Expand Down
69 changes: 49 additions & 20 deletions source/src_lcao/local_orbital_wfc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,28 +101,28 @@ void Local_Orbital_wfc::allocate_k(const Grid_Technique& gt,
for(int ik=0; ik<GlobalC::kv.nkstot; ++ik)
{
GlobalV::ofs_running << " Read in wave functions " << ik + 1 << std::endl;
error = WF_Local::read_lowf_complex( this->wfc_k_grid[ik], ik, lowf);
}
error = WF_Local::read_lowf_complex(this->wfc_k_grid[ik], ik, lowf);
#ifdef __MPI
Parallel_Common::bcast_int(error);
Parallel_Common::bcast_int(error);
#endif
GlobalV::ofs_running << " Error=" << error << std::endl;
if(error==1)
{
ModuleBase::WARNING_QUIT("Local_Orbital_wfc","Can't find the wave function file: LOWF.dat");
}
else if(error==2)
{
ModuleBase::WARNING_QUIT("Local_Orbital_wfc","In wave function file, band number doesn't match");
}
else if(error==3)
{
ModuleBase::WARNING_QUIT("Local_Orbital_wfc","In wave function file, nlocal doesn't match");
}
else if(error==4)
{
ModuleBase::WARNING_QUIT("Local_Orbital_wfc","In k-dependent wave function file, k point is not correct");
}
GlobalV::ofs_running << " Error=" << error << std::endl;
if(error==1)
{
ModuleBase::WARNING_QUIT("Local_Orbital_wfc","Can't find the wave function file: LOWF.dat");
}
else if(error==2)
{
ModuleBase::WARNING_QUIT("Local_Orbital_wfc","In wave function file, band number doesn't match");
}
else if(error==3)
{
ModuleBase::WARNING_QUIT("Local_Orbital_wfc","In wave function file, nlocal doesn't match");
}
else if(error==4)
{
ModuleBase::WARNING_QUIT("Local_Orbital_wfc","In k-dependent wave function file, k point is not correct");
}
}
}
else
{
Expand Down Expand Up @@ -172,6 +172,35 @@ int Local_Orbital_wfc::q2CTOT(
return 0;
}

int Local_Orbital_wfc::q2WFC(
int myid,
int naroc[2],
int nb,
int dim0,
int dim1,
int iprow,
int ipcol,
int loc_size,
double* work,
double** WFC)
{
ModuleBase::TITLE(" Local_Orbital_wfc","q2WFC");
for (int j = 0; j < naroc[1]; ++j)
{
int igcol=globalIndex(j, nb, dim1, ipcol);
if(igcol>=GlobalV::NBANDS) continue;
for(int i=0; i<naroc[0]; ++i)
{
int igrow = globalIndex(i, nb, dim0, iprow);
int mu_local = GlobalC::GridT.trace_lo[igrow];
if (mu_local >= 0 )
{
WFC[igcol][mu_local]=work[j*naroc[0]+i];
}
}
}
return 0;
}

int Local_Orbital_wfc::q2WFC_complex(
int naroc[2],
Expand Down
Loading