Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 78 additions & 86 deletions source/module_esolver/esolver_ks.cpp
Original file line number Diff line number Diff line change
@@ -1,12 +1,5 @@
#include "esolver_ks.h"

#include <ctime>
#include <iostream>
#ifdef __MPI
#include <mpi.h>
#else
#include <chrono>
#endif
#include "module_base/timer.h"
#include "module_cell/cal_atoms_info.h"
#include "module_io/json_output/init_info.h"
Expand All @@ -15,6 +8,9 @@
#include "module_io/print_info.h"
#include "module_io/write_istate_info.h"
#include "module_parameter/parameter.h"

#include <ctime>
#include <iostream>
//--------------Temporary----------------
#include "module_base/global_variable.h"
#include "module_hamilt_lcao/module_dftu/dftu.h"
Expand Down Expand Up @@ -427,49 +423,11 @@ void ESolver_KS<T, Device>::runner(const int istep, UnitCell& ucell)
this->niter = this->maxniter;

// 4) SCF iterations
double diag_ethr = PARAM.inp.pw_diag_thr;
this->diag_ethr = PARAM.inp.pw_diag_thr;

std::cout << " * * * * * *\n << Start SCF iteration." << std::endl;
for (int iter = 1; iter <= this->maxniter; ++iter)
{
// 5) write head
ModuleIO::write_head(GlobalV::ofs_running, istep, iter, this->basisname);

#ifdef __MPI
auto iterstart = MPI_Wtime();
#else
auto iterstart = std::chrono::system_clock::now();
#endif

if (PARAM.inp.esolver_type == "ksdft")
{
diag_ethr = hsolver::set_diagethr_ks(PARAM.inp.basis_type,
PARAM.inp.esolver_type,
PARAM.inp.calculation,
PARAM.inp.init_chg,
PARAM.inp.precision,
istep,
iter,
drho,
PARAM.inp.pw_diag_thr,
diag_ethr,
PARAM.inp.nelec);
}
else if (PARAM.inp.esolver_type == "sdft")
{
diag_ethr = hsolver::set_diagethr_sdft(PARAM.inp.basis_type,
PARAM.inp.esolver_type,
PARAM.inp.calculation,
PARAM.inp.init_chg,
istep,
iter,
drho,
PARAM.inp.pw_diag_thr,
diag_ethr,
PARAM.inp.nbands,
esolver_KS_ne);
}

// 6) initialization of SCF iterations
this->iter_init(istep, iter);

Expand Down Expand Up @@ -615,33 +573,6 @@ void ESolver_KS<T, Device>::runner(const int istep, UnitCell& ucell)

// 10) finish scf iterations
this->iter_finish(istep, iter);
#ifdef __MPI
double duration = (double)(MPI_Wtime() - iterstart);
#else
double duration
= (std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::system_clock::now() - iterstart))
.count()
/ static_cast<double>(1e6);
#endif

// 11) get mtaGGA related parameters
double dkin = 0.0; // for meta-GGA
if (XC_Functional::get_func_type() == 3 || XC_Functional::get_func_type() == 5)
{
dkin = p_chgmix->get_dkin(pelec->charge, PARAM.inp.nelec);
}
this->pelec->print_etot(this->conv_esolver, iter, drho, dkin, duration, PARAM.inp.printe, diag_ethr);

// 12) Json, need to be moved to somewhere else
#ifdef __RAPIDJSON
// add Json of scf mag
Json::add_output_scf_mag(GlobalC::ucell.magnet.tot_magnetization,
GlobalC::ucell.magnet.abs_magnetization,
this->pelec->f_en.etot * ModuleBase::Ry_to_eV,
this->pelec->f_en.etot_delta * ModuleBase::Ry_to_eV,
drho,
duration);
#endif //__RAPIDJSON

// 13) check convergence
if (this->conv_esolver || this->oscillate_esolver)
Expand All @@ -653,12 +584,6 @@ void ESolver_KS<T, Device>::runner(const int istep, UnitCell& ucell)
}
break;
}

// notice for restart
if (PARAM.inp.mixing_restart > 0 && iter == this->p_chgmix->mixing_restart_step - 1 && iter != PARAM.inp.scf_nmax)
{
std::cout << " SCF restart after this step!" << std::endl;
}
} // end scf iterations
std::cout << " >> Leave SCF iteration.\n * * * * * *" << std::endl;

Expand All @@ -671,6 +596,47 @@ void ESolver_KS<T, Device>::runner(const int istep, UnitCell& ucell)
return;
};

template <typename T, typename Device>
void ESolver_KS<T, Device>::iter_init(const int istep, const int iter)
{
ModuleIO::write_head(GlobalV::ofs_running, istep, iter, this->basisname);

#ifdef __MPI
iter_time = MPI_Wtime();
#else
iter_time = std::chrono::system_clock::now();
#endif

if (PARAM.inp.esolver_type == "ksdft")
{
diag_ethr = hsolver::set_diagethr_ks(PARAM.inp.basis_type,
PARAM.inp.esolver_type,
PARAM.inp.calculation,
PARAM.inp.init_chg,
PARAM.inp.precision,
istep,
iter,
drho,
PARAM.inp.pw_diag_thr,
diag_ethr,
PARAM.inp.nelec);
}
else if (PARAM.inp.esolver_type == "sdft")
{
diag_ethr = hsolver::set_diagethr_sdft(PARAM.inp.basis_type,
PARAM.inp.esolver_type,
PARAM.inp.calculation,
PARAM.inp.init_chg,
istep,
iter,
drho,
PARAM.inp.pw_diag_thr,
diag_ethr,
PARAM.inp.nbands,
esolver_KS_ne);
}
}

template <typename T, typename Device>
void ESolver_KS<T, Device>::iter_finish(const int istep, int& iter)
{
Expand All @@ -684,6 +650,39 @@ void ESolver_KS<T, Device>::iter_finish(const int istep, int& iter)
}
this->pelec->f_en.etot_delta = this->pelec->f_en.etot - this->pelec->f_en.etot_old;
this->pelec->f_en.etot_old = this->pelec->f_en.etot;

#ifdef __MPI
double duration = (double)(MPI_Wtime() - iter_time);
#else
double duration
= (std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::system_clock::now() - iter_time)).count()
/ static_cast<double>(1e6);
#endif

// get mtaGGA related parameters
double dkin = 0.0; // for meta-GGA
if (XC_Functional::get_func_type() == 3 || XC_Functional::get_func_type() == 5)
{
dkin = p_chgmix->get_dkin(pelec->charge, PARAM.inp.nelec);
}
this->pelec->print_etot(this->conv_esolver, iter, drho, dkin, duration, PARAM.inp.printe, diag_ethr);

// Json, need to be moved to somewhere else
#ifdef __RAPIDJSON
// add Json of scf mag
Json::add_output_scf_mag(GlobalC::ucell.magnet.tot_magnetization,
GlobalC::ucell.magnet.abs_magnetization,
this->pelec->f_en.etot * ModuleBase::Ry_to_eV,
this->pelec->f_en.etot_delta * ModuleBase::Ry_to_eV,
drho,
duration);
#endif //__RAPIDJSON

// notice for restart
if (PARAM.inp.mixing_restart > 0 && iter == this->p_chgmix->mixing_restart_step - 1 && iter != PARAM.inp.scf_nmax)
{
std::cout << " SCF restart after this step!" << std::endl;
}
}

//! Something to do after SCF iterations when SCF is converged or comes to the max iter step.
Expand All @@ -698,13 +697,6 @@ void ESolver_KS<T, Device>::after_scf(const int istep)
{
this->pelec->print_eigenvalue(GlobalV::ofs_running);
}
// #ifdef __RAPIDJSON
// // add Json of efermi energy converge
// Json::add_output_efermi_converge(this->pelec->eferm.ef * ModuleBase::Ry_to_eV, this->conv_esolver);
// // add nkstot,nkstot_ibz to output json
// int Jnkstot = this->pelec->klist->get_nkstot();
// Json::add_nkstot(Jnkstot);
// #endif //__RAPIDJSON
}

//------------------------------------------------------------------------------
Expand Down
104 changes: 53 additions & 51 deletions source/module_esolver/esolver_ks.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,79 +10,81 @@
#include "module_io/cal_test.h"
#include "module_psi/psi.h"

#include <fstream>
#ifdef __MPI
#include <mpi.h>
#else
#include <chrono>
#endif
#include <cstring>
#include <fstream>
namespace ModuleESolver
{

template <typename T, typename Device = base_device::DEVICE_CPU>
class ESolver_KS : public ESolver_FP
{
public:

//! Constructor
ESolver_KS();

//! Deconstructor
virtual ~ESolver_KS();

double scf_thr; // scf density threshold
public:
//! Constructor
ESolver_KS();

double scf_ene_thr; // scf energy threshold
//! Deconstructor
virtual ~ESolver_KS();

double drho; // the difference between rho_in (before HSolver) and rho_out (After HSolver)
virtual void before_all_runners(const Input_para& inp, UnitCell& cell) override;

int maxniter; // maximum iter steps for scf
virtual void runner(const int istep, UnitCell& cell) override;

int niter; // iter steps actually used in scf
protected:
//! Something to do before SCF iterations.
virtual void before_scf(const int istep) {};

int out_freq_elec; // frequency for output
virtual void init_after_vc(const Input_para& inp, UnitCell& cell) override; // liuyu add 2023-03-09

virtual void before_all_runners(const Input_para& inp, UnitCell& cell) override;
//! Something to do before hamilt2density function in each iter loop.
virtual void iter_init(const int istep, const int iter);

virtual void init_after_vc(const Input_para& inp, UnitCell& cell) override; // liuyu add 2023-03-09
//! Something to do after hamilt2density function in each iter loop.
virtual void iter_finish(const int istep, int& iter);

virtual void runner(const int istep, UnitCell& cell) override;
// calculate electron density from a specific Hamiltonian
virtual void hamilt2density(const int istep, const int iter, const double ethr);

// calculate electron density from a specific Hamiltonian
virtual void hamilt2density(const int istep, const int iter, const double ethr);
// calculate electron states from a specific Hamiltonian
virtual void hamilt2estates(const double ethr) {};

// calculate electron states from a specific Hamiltonian
virtual void hamilt2estates(const double ethr){};
//! Something to do after SCF iterations when SCF is converged or comes to the max iter step.
virtual void after_scf(const int istep) override;

protected:
//! Something to do before SCF iterations.
virtual void before_scf(const int istep) {};
//! <Temporary> It should be replaced by a function in Hamilt Class
virtual void update_pot(const int istep, const int iter) {};

//! Something to do before hamilt2density function in each iter loop.
virtual void iter_init(const int istep, const int iter) {};
//! Hamiltonian
hamilt::Hamilt<T, Device>* p_hamilt = nullptr;

//! Something to do after hamilt2density function in each iter loop.
virtual void iter_finish(const int istep, int& iter);
ModulePW::PW_Basis_K* pw_wfc = nullptr;

//! Something to do after SCF iterations when SCF is converged or comes to the max iter step.
virtual void after_scf(const int istep) override;
Charge_Mixing* p_chgmix = nullptr;

//! <Temporary> It should be replaced by a function in Hamilt Class
virtual void update_pot(const int istep, const int iter) {};
wavefunc wf;

protected:
//! Hamiltonian
hamilt::Hamilt<T, Device>* p_hamilt = nullptr;
// wavefunction coefficients
psi::Psi<T>* psi = nullptr;

ModulePW::PW_Basis_K* pw_wfc = nullptr;

Charge_Mixing* p_chgmix = nullptr;

wavefunc wf;

// wavefunction coefficients
psi::Psi<T>* psi = nullptr;

protected:
std::string basisname; // PW or LCAO
double esolver_KS_ne = 0.0;
bool oscillate_esolver = false; // whether esolver is oscillated
};
} // end of namespace
std::string basisname; // PW or LCAO
double esolver_KS_ne = 0.0;
bool oscillate_esolver = false; // whether esolver is oscillated
#ifdef __MPI
double iter_time; // the start time of scf iteration
#else
std::chrono::system_clock::time_point iter_time;
#endif
double diag_ethr; // the threshold for diagonalization
double scf_thr; // scf density threshold
double scf_ene_thr; // scf energy threshold
double drho; // the difference between rho_in (before HSolver) and rho_out (After HSolver)
int maxniter; // maximum iter steps for scf
int niter; // iter steps actually used in scf
int out_freq_elec; // frequency for output
};
} // namespace ModuleESolver
#endif
3 changes: 3 additions & 0 deletions source/module_esolver/esolver_ks_lcao.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,9 @@ void ESolver_KS_LCAO<TK, TR>::iter_init(const int istep, const int iter)
{
ModuleBase::TITLE("ESolver_KS_LCAO", "iter_init");

// call iter_init() of ESolver_KS
ESolver_KS<TK>::iter_init(istep, iter);

if (iter == 1)
{
this->p_chgmix->init_mixing(); // init mixing
Expand Down
3 changes: 3 additions & 0 deletions source/module_esolver/esolver_ks_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,9 @@ void ESolver_KS_PW<T, Device>::before_scf(const int istep)
template <typename T, typename Device>
void ESolver_KS_PW<T, Device>::iter_init(const int istep, const int iter)
{
// call iter_init() of ESolver_KS
ESolver_KS<T, Device>::iter_init(istep, iter);

if (iter == 1)
{
this->p_chgmix->init_mixing();
Expand Down
Loading