Skip to content

Commit

Permalink
Feature : loading equivariant deepks model (#4137)
Browse files Browse the repository at this point in the history
* Feature : loading equivariant deepks model

* slight modification of the python script

* removed obsolete code

* initialize variable

* udpate ut

* add force calculation

* update header

* fix bug loading npy file

* add check in cal_gedm_equiv for execution of python script

* fix bug in calculation of gedm

---------

Co-authored-by: wenfei-li <liwenfei@gmail.com>
Co-authored-by: Mohan Chen <mohan.chen.chen.mohan@gmail.com>
  • Loading branch information
3 people committed May 16, 2024
1 parent 24a27b9 commit 99f4592
Show file tree
Hide file tree
Showing 15 changed files with 319 additions and 1,151 deletions.
2 changes: 0 additions & 2 deletions source/module_hamilt_lcao/hamilt_lcaodft/FORCE_gamma.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,6 @@ void Force_LCAO_gamma::ftable_gamma(const bool isforce,
GlobalC::ld.check_projected_dm();
GlobalC::ld.check_descriptor(GlobalC::ucell);
GlobalC::ld.check_gedm();
GlobalC::ld.add_v_delta(GlobalC::ucell, GlobalC::ORB, GlobalC::GridD);
GlobalC::ld.check_v_delta();

GlobalC::ld.cal_e_delta_band(dm_gamma);
std::ofstream ofs("E_delta_bands.dat");
Expand Down
19 changes: 0 additions & 19 deletions source/module_hamilt_lcao/hamilt_lcaodft/FORCE_k.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,25 +143,6 @@ void Force_LCAO_k::ftable_k(const bool isforce,
Parallel_Reduce::reduce_pool(svnl_dalpha.c, svnl_dalpha.nr * svnl_dalpha.nc);
}
#endif
/*if (GlobalV::deepks_out_unittest)
{
GlobalC::ld.print_dm_k(kv.nks, dm_k);
GlobalC::ld.check_projected_dm();
GlobalC::ld.check_descriptor(GlobalC::ucell);
GlobalC::ld.check_gedm();
GlobalC::ld.add_v_delta_k(GlobalC::ucell, GlobalC::ORB, GlobalC::GridD, pv->nnr);
GlobalC::ld.check_v_delta_k(pv->nnr);
for (int ik = 0; ik < kv.nks; ik++)
{
LM->folding_fixedH(ik, kv.kvec_d);
}
GlobalC::ld.cal_e_delta_band_k(dm_k, kv.nks);
std::ofstream ofs("E_delta_bands.dat");
ofs << std::setprecision(10) << GlobalC::ld.e_delta_band;
std::ofstream ofs1("E_delta.dat");
ofs1 << std::setprecision(10) << GlobalC::ld.E_delta;
GlobalC::ld.check_f_delta(GlobalC::ucell.nat, svnl_dalpha);
}*/
}
#endif

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,6 @@ void DeePKS<OperatorLCAO<double, double>>::contributeHR()
GlobalC::GridD);
GlobalC::ld.cal_descriptor(this->ucell->nat);
GlobalC::ld.cal_gedm(this->ucell->nat);
//GlobalC::ld.add_v_delta(*this->ucell,
// GlobalC::ORB,
// GlobalC::GridD);
// recalculate the H_V_delta
this->H_V_delta->set_zero();
this->calculate_HR();
Expand Down Expand Up @@ -193,9 +190,6 @@ void DeePKS<OperatorLCAO<std::complex<double>, double>>::contributeHR()
GlobalC::ld.cal_descriptor(this->ucell->nat);
// calculate dE/dD
GlobalC::ld.cal_gedm(this->ucell->nat);

// calculate H_V_deltaR from saved <alpha(0)|psi(R)>
//GlobalC::ld.add_v_delta_k(*this->ucell, GlobalC::ORB, GlobalC::GridD, this->LM->ParaV->nnr);

// recalculate the H_V_delta
if(this->H_V_delta == nullptr)
Expand Down Expand Up @@ -233,10 +227,6 @@ void DeePKS<OperatorLCAO<std::complex<double>, std::complex<double>>>::contribut
// calculate dE/dD
GlobalC::ld.cal_gedm(this->ucell->nat);

// calculate H_V_deltaR from saved <alpha(0)|psi(R)>
//GlobalC::ld
// .add_v_delta_k(*this->ucell, GlobalC::ORB, GlobalC::GridD, this->LM->ParaV->nnr);

// recalculate the H_V_delta
if(this->H_V_delta == nullptr)
{
Expand Down Expand Up @@ -347,25 +337,46 @@ void hamilt::DeePKS<hamilt::OperatorLCAO<TK, TR>>::calculate_HR()
std::vector<int> trace_alpha_row;
std::vector<int> trace_alpha_col;
std::vector<double> gedms;
int ib=0;
for (int L0 = 0; L0 <= orb.Alpha[0].getLmax();++L0)
if(!GlobalC::ld.get_if_equiv())
{
for (int N0 = 0;N0 < orb.Alpha[0].getNchi(L0);++N0)
int ib=0;
for (int L0 = 0; L0 <= orb.Alpha[0].getLmax();++L0)
{
const int inl = GlobalC::ld.get_inl(T0, I0, L0, N0);
const double* pgedm = GlobalC::ld.get_gedms(inl);
const int nm = 2*L0+1;

for (int m1=0; m1<nm; ++m1) // m1 = 1 for s, 3 for p, 5 for d
for (int N0 = 0;N0 < orb.Alpha[0].getNchi(L0);++N0)
{
for (int m2=0; m2<nm; ++m2) // m1 = 1 for s, 3 for p, 5 for d
const int inl = GlobalC::ld.get_inl(T0, I0, L0, N0);
const double* pgedm = GlobalC::ld.get_gedms(inl);
const int nm = 2*L0+1;

for (int m1=0; m1<nm; ++m1) // m1 = 1 for s, 3 for p, 5 for d
{
trace_alpha_row.push_back(ib+m1);
trace_alpha_col.push_back(ib+m2);
gedms.push_back(pgedm[m1*nm+m2]);
for (int m2=0; m2<nm; ++m2) // m1 = 1 for s, 3 for p, 5 for d
{
trace_alpha_row.push_back(ib+m1);
trace_alpha_col.push_back(ib+m2);
gedms.push_back(pgedm[m1*nm+m2]);
}
}
ib+=nm;
}
}
}
else
{
const double * pgedm = GlobalC::ld.get_gedms(iat0);
int nproj = 0;
for(int il = 0; il < GlobalC::ld.get_lmaxd() + 1; il++)
{
nproj += (2 * il + 1) * orb.Alpha[0].getNchi(il);
}
for(int iproj = 0; iproj < nproj; iproj ++)
{
for(int jproj = 0; jproj < nproj; jproj ++)
{
trace_alpha_row.push_back(iproj);
trace_alpha_col.push_back(jproj);
gedms.push_back(pgedm[iproj*nproj+jproj]);
}
ib+=nm;
}
}
const int trace_alpha_size = trace_alpha_row.size();
Expand Down
47 changes: 38 additions & 9 deletions source/module_hamilt_lcao/module_deepks/LCAO_deepks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,19 +214,29 @@ void LCAO_Deepks::init_gdmx(const int nat)
this->gdmx = new double** [nat];
this->gdmy = new double** [nat];
this->gdmz = new double** [nat];
int pdm_size = 0;
if(!if_equiv)
{
pdm_size = (this->lmaxd * 2 + 1) * (this->lmaxd * 2 + 1);
}
else
{
pdm_size = this -> des_per_atom;
}

for (int iat = 0;iat < nat;iat++)
{
this->gdmx[iat] = new double* [inlmax];
this->gdmy[iat] = new double* [inlmax];
this->gdmz[iat] = new double* [inlmax];
for (int inl = 0;inl < inlmax;inl++)
{
this->gdmx[iat][inl] = new double [(2 * lmaxd + 1) * (2 * lmaxd + 1)];
this->gdmy[iat][inl] = new double [(2 * lmaxd + 1) * (2 * lmaxd + 1)];
this->gdmz[iat][inl] = new double[(2 * lmaxd + 1) * (2 * lmaxd + 1)];
ModuleBase::GlobalFunc::ZEROS(gdmx[iat][inl], (2 * lmaxd + 1) * (2 * lmaxd + 1));
ModuleBase::GlobalFunc::ZEROS(gdmy[iat][inl], (2 * lmaxd + 1) * (2 * lmaxd + 1));
ModuleBase::GlobalFunc::ZEROS(gdmz[iat][inl], (2 * lmaxd + 1) * (2 * lmaxd + 1));
this->gdmx[iat][inl] = new double [pdm_size];
this->gdmy[iat][inl] = new double [pdm_size];
this->gdmz[iat][inl] = new double[pdm_size];
ModuleBase::GlobalFunc::ZEROS(gdmx[iat][inl], pdm_size);
ModuleBase::GlobalFunc::ZEROS(gdmy[iat][inl], pdm_size);
ModuleBase::GlobalFunc::ZEROS(gdmz[iat][inl], pdm_size);
}
}
this->nat_gdm = nat;
Expand Down Expand Up @@ -258,13 +268,23 @@ void LCAO_Deepks::init_gdmepsl()
{
this->gdm_epsl = new double** [6];

int pdm_size = 0;
if(!if_equiv)
{
pdm_size = (this->lmaxd * 2 + 1) * (this->lmaxd * 2 + 1);
}
else
{
pdm_size = this -> des_per_atom;
}

for (int ipol = 0;ipol < 6;ipol++)
{
this->gdm_epsl[ipol] = new double* [inlmax];
for (int inl = 0;inl < inlmax;inl++)
{
this->gdm_epsl[ipol][inl] = new double [(2 * lmaxd + 1) * (2 * lmaxd + 1)];
ModuleBase::GlobalFunc::ZEROS(gdm_epsl[ipol][inl], (2 * lmaxd + 1) * (2 * lmaxd + 1));
this->gdm_epsl[ipol][inl] = new double [pdm_size];
ModuleBase::GlobalFunc::ZEROS(gdm_epsl[ipol][inl], pdm_size);
}
}
return;
Expand Down Expand Up @@ -307,7 +327,16 @@ void LCAO_Deepks::allocate_V_delta(const int nat, const int nks)
}

//init gedm**
const int pdm_size = (this->lmaxd * 2 + 1) * (this->lmaxd * 2 + 1);
int pdm_size = 0;
if(!if_equiv)
{
pdm_size = (this->lmaxd * 2 + 1) * (this->lmaxd * 2 + 1);
}
else
{
pdm_size = this -> des_per_atom;
}

this->gedm = new double* [this->inlmax];
for (int inl = 0;inl < this->inlmax;inl++)
{
Expand Down
22 changes: 6 additions & 16 deletions source/module_hamilt_lcao/module_deepks/LCAO_deepks.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ class LCAO_Deepks
// temporary add two getters for inl_index and gedm
int get_inl(const int& T0, const int& I0, const int& L0, const int& N0) { return inl_index[T0](I0, L0, N0); }
const double* get_gedms(const int& inl){ return gedm[inl]; }

bool get_if_equiv(){return if_equiv;}
int get_lmaxd(){return lmaxd;}
//-------------------
// private variables
//-------------------
Expand Down Expand Up @@ -332,27 +335,11 @@ class LCAO_Deepks
//tr (rho * V_delta)

//Four subroutines are contained in the file:
//1. add_v_delta : adds deepks contribution to hamiltonian, for gamma only
//2. add_v_delta_k : counterpart of 1, for multi-k
//3. check_v_delta : prints H_V_delta for checking
//4. check_v_delta_k : prints H_V_deltaR for checking
//5. cal_e_delta_band : calculates e_delta_bands for gamma only
//6. cal_e_delta_band_k : counterpart of 4, for multi-k

public:

///add dV to the Hamiltonian matrix
void add_v_delta(const UnitCell &ucell,
const LCAO_Orbitals &orb,
Grid_Driver& GridD);
void add_v_delta_k(const UnitCell &ucell,
const LCAO_Orbitals &orb,
Grid_Driver& GridD,
const int nnr_in);

void check_v_delta();
void check_v_delta_k(const int nnr);

///calculate tr(\rho V_delta)
//void cal_e_delta_band(const std::vector<ModuleBase::matrix>& dm/**<[in] density matrix*/);
void cal_e_delta_band(const std::vector<std::vector<double>>& dm/**<[in] density matrix*/);
Expand Down Expand Up @@ -477,6 +464,7 @@ class LCAO_Deepks
///calculate partial of energy correction to descriptors
void cal_gedm(const int nat);
void check_gedm(void);
void cal_gedm_equiv(const int nat);

//calculates orbital_precalc
void cal_orbital_precalc(const std::vector<std::vector<ModuleBase::matrix>>& dm_hl/**<[in] density matrix*/,
Expand Down Expand Up @@ -554,6 +542,8 @@ class LCAO_Deepks
void save_npy_o(const ModuleBase::matrix &bandgap/**<[in] \f$E_{base}\f$ or \f$E_{tot}\f$, in Ry*/, const std::string &o_file, const int nks);
void save_npy_orbital_precalc(const int nat, const int nks);

void load_npy_gedm(const int nat);

//-------------------
// LCAO_deepks_mpi.cpp
//-------------------
Expand Down

0 comments on commit 99f4592

Please sign in to comment.