Skip to content

Commit

Permalink
Refactor: move vdw classes to module_vdw and clean up (#1448)
Browse files Browse the repository at this point in the history
* vdw: move source files from src_pw to module_vdw and clean up

* vdw: combine calculate and get

* vdw: downgrade to c++11 and clean up

* vdw: replace old vdw classes with new ones

* vdw: remove old files

* vdw: fix deepks test compile
  • Loading branch information
xinyangd committed Oct 28, 2022
1 parent 42ca7e9 commit 64a23b8
Show file tree
Hide file tree
Showing 36 changed files with 2,212 additions and 2,179 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,7 @@ target_link_libraries(${ABACUS_BIN_NAME}
hamilt
psi
esolver
vdw
)

if(ENABLE_LCAO)
Expand Down
1 change: 1 addition & 0 deletions source/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ add_subdirectory(src_pdiag)
add_subdirectory(src_pw)
add_subdirectory(src_ri)
add_subdirectory(module_rpa)
add_subdirectory(module_vdw)

add_library(
driver
Expand Down
2 changes: 1 addition & 1 deletion source/module_base/element_name.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
namespace ModuleBase
{

const std::vector<std::string> element_name = {
static const std::vector<std::string> element_name = {
"H" ,
"He" ,
"Li" ,
Expand Down
2 changes: 1 addition & 1 deletion source/module_deepks/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ target_link_libraries(
test_deepks
base cell symmetry md surchem xc_
neighbor orb io relax gint lcao parallel mrrr pdiag pw ri driver esolver hsolver psi elecstate hamilt planewave
pthread genelpa
pthread genelpa vdw
deepks rpa
${ABACUS_LINK_LIBRARIES}
)
Expand Down
28 changes: 4 additions & 24 deletions source/module_esolver/esolver_ks_lcao_elec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,12 @@
#include "../src_ri/exx_opt_orb.h"
#include "../src_io/berryphase.h"
#include "../src_io/to_wannier90.h"
#include "../src_pw/vdwd2.h"
#include "../src_pw/vdwd3.h"
#include "../module_base/timer.h"
#ifdef __DEEPKS
#include "../module_deepks/LCAO_deepks.h"
#endif
#include "../src_pw/H_Ewald_pw.h"
#include "module_vdw/vdw.h"

namespace ModuleESolver
{
Expand Down Expand Up @@ -279,29 +278,10 @@ namespace ModuleESolver
//----------------------------------------------------------
// about vdw, jiyy add vdwd3 and linpz add vdwd2
//----------------------------------------------------------
if (INPUT.vdw_method == "d2")
auto vdw_solver = vdw::make_vdw(GlobalC::ucell, INPUT);
if (vdw_solver != nullptr)
{
// setup vdwd2 parameters
GlobalC::vdwd2_para.initial_parameters(INPUT);
GlobalC::vdwd2_para.initset(GlobalC::ucell);
}
if (INPUT.vdw_method == "d3_0" || INPUT.vdw_method == "d3_bj")
{
GlobalC::vdwd3_para.initial_parameters(INPUT);
}
// Peize Lin add 2014.04.04, update 2021.03.09
if (GlobalC::vdwd2_para.flag_vdwd2)
{
Vdwd2 vdwd2(GlobalC::ucell, GlobalC::vdwd2_para);
vdwd2.cal_energy();
GlobalC::en.evdw = vdwd2.get_energy();
}
// jiyy add 2019-05-18, update 2021.05.02
else if (GlobalC::vdwd3_para.flag_vdwd3)
{
Vdwd3 vdwd3(GlobalC::ucell, GlobalC::vdwd3_para);
vdwd3.cal_energy();
GlobalC::en.evdw = vdwd3.get_energy();
GlobalC::en.evdw = vdw_solver->get_energy();
}

this->beforesolver(istep);
Expand Down
30 changes: 5 additions & 25 deletions source/module_esolver/esolver_ks_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,6 @@
#include "../src_pw/global.h"
#include "../module_base/global_function.h"
#include "../module_symmetry/symmetry.h"
#include "../src_pw/vdwd2.h"
#include "../src_pw/vdwd3.h"
#include "../src_pw/vdwd2_parameters.h"
#include "../src_pw/vdwd3_parameters.h"
#include "../src_pw/pw_complement.h"
#include "../src_pw/structure_factor.h"
#include "../src_pw/symmetry_rho.h"
Expand All @@ -30,6 +26,7 @@
#include "module_elecstate/elecstate_pw.h"
#include "module_hamilt/hamilt_pw.h"
#include "module_hsolver/diago_iter_assist.h"
#include "module_vdw/vdw.h"

#include "src_io/write_wfc_realspace.h"
#include "src_io/winput.h"
Expand Down Expand Up @@ -196,29 +193,12 @@ namespace ModuleESolver

//----------------------------------------------------------
// about vdw, jiyy add vdwd3 and linpz add vdwd2
//----------------------------------------------------------
if(INPUT.vdw_method=="d2")
{
// setup vdwd2 parameters
GlobalC::vdwd2_para.initial_parameters(INPUT);
GlobalC::vdwd2_para.initset(GlobalC::ucell);
}
if(INPUT.vdw_method=="d3_0" || INPUT.vdw_method=="d3_bj")
//----------------------------------------------------------
auto vdw_solver = vdw::make_vdw(GlobalC::ucell, INPUT);
if (vdw_solver != nullptr)
{
GlobalC::vdwd3_para.initial_parameters(INPUT);
GlobalC::en.evdw = vdw_solver->get_energy();
}
if(GlobalC::vdwd2_para.flag_vdwd2) //Peize Lin add 2014-04-03, update 2021-03-09
{
Vdwd2 vdwd2(GlobalC::ucell,GlobalC::vdwd2_para);
vdwd2.cal_energy();
GlobalC::en.evdw = vdwd2.get_energy();
}
if(GlobalC::vdwd3_para.flag_vdwd3) //jiyy add 2019-05-18, update 2021-05-02
{
Vdwd3 vdwd3(GlobalC::ucell,GlobalC::vdwd3_para);
vdwd3.cal_energy();
GlobalC::en.evdw = vdwd3.get_energy();
}

//calculate ewald energy
if(!GlobalV::test_skip_ewald)
Expand Down
11 changes: 8 additions & 3 deletions source/module_vdw/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
add_library(
base
OBJECT
xxx.cpp
vdw
OBJECT
vdwd2_parameters.cpp
vdwd3_parameters_tab.cpp
vdwd3_parameters.cpp
vdwd2.cpp
vdwd3.cpp
vdw.cpp
)

if(ENABLE_COVERAGE)
Expand Down
30 changes: 30 additions & 0 deletions source/module_vdw/vdw.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@

#include "vdw.h"
#include "vdwd2.h"
#include "vdwd3.h"

namespace vdw
{

std::unique_ptr<Vdw> make_vdw(const UnitCell_pseudo &ucell, const Input &input)
{
if (INPUT.vdw_method == "d2")
{
std::unique_ptr<Vdwd2> vdw_ptr = make_unique<Vdwd2>(ucell);
vdw_ptr->parameter().initial_parameters(input);
vdw_ptr->parameter().initset(ucell);
return vdw_ptr;
}
else if (INPUT.vdw_method == "d3_0" || INPUT.vdw_method == "d3_bj")
{
std::unique_ptr<Vdwd3> vdw_ptr = make_unique<Vdwd3>(ucell);
vdw_ptr->parameter().initial_parameters(input);
return vdw_ptr;
}
else
{
return nullptr;
}
}

} // namespace vdw
54 changes: 54 additions & 0 deletions source/module_vdw/vdw.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#ifndef VDW_H
#define VDW_H

#include <vector>
#include "module_cell/unitcell_pseudo.h"
#include "module_vdw/vdw_parameters.h"
#include "module_vdw/vdwd2_parameters.h"
#include "module_vdw/vdwd3_parameters.h"

namespace vdw
{

template<typename T, typename... Args>
std::unique_ptr<T> make_unique(Args &&... args) {
return std::unique_ptr<T>(new T(std::forward<Args>(args)...));
}

class Vdw
{
public:
Vdw(const UnitCell_pseudo &unit_in) : ucell_(unit_in) {};

virtual ~Vdw(){};

inline double get_energy(bool cal=true) {
if (cal) { cal_energy(); }
return energy_;
}
inline const std::vector<ModuleBase::Vector3<double>> &get_force(bool cal=true) {
if (cal) { cal_force(); }
return force_;
}
inline const ModuleBase::Matrix3 &get_stress(bool cal=true) {
if (cal) { cal_stress(); }
return stress_;
}

protected:
const UnitCell_pseudo &ucell_;

double energy_ = 0;
std::vector<ModuleBase::Vector3<double>> force_;
ModuleBase::Matrix3 stress_;

virtual void cal_energy() { throw std::runtime_error("No cal_energy method in base Vdw class"); }
virtual void cal_force() { throw std::runtime_error("No cal_energy method in base Vdw class"); }
virtual void cal_stress() { throw std::runtime_error("No cal_energy method in base Vdw class"); }
};

std::unique_ptr<Vdw> make_vdw(const UnitCell_pseudo &ucell, const Input &input);

} // namespace vdw

#endif // VDW_H
27 changes: 27 additions & 0 deletions source/module_vdw/vdw_parameters.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#ifndef VDW_PARAMETERS_H
#define VDW_PARAMETERS_H

#include "module_base/vector3.h"

#include <string>

namespace vdw
{

class VdwParameters
{
public:
VdwParameters() = default;
virtual ~VdwParameters() = default;

inline const std::string &model() const { return model_; }
inline const ModuleBase::Vector3<int> &period() const { return period_; };

protected:
std::string model_;
ModuleBase::Vector3<int> period_;
};

} // namespace vdw

#endif // VDW_PARAMETERS_H
88 changes: 88 additions & 0 deletions source/module_vdw/vdwd2.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
//==========================================================
// AUTHOR : Peize Lin
// DATE : 2014-04-25
// UPDATE : 2019-04-26
//==========================================================

#include "module_vdw/vdwd2.h"

namespace vdw
{

void Vdwd2::cal_energy()
{
ModuleBase::TITLE("Vdwd2", "energy");
para_.initset(ucell_);
energy_ = 0;

auto energy = [&](double r,
double R0_sum,
double C6_product,
double r_sqr,
int,
int,
const ModuleBase::Vector3<double> &,
const ModuleBase::Vector3<double> &) {
const double tmp_damp_recip = 1 + exp(-para_.damping() * (r / R0_sum - 1));
energy_ -= C6_product / pow(r_sqr, 3) / tmp_damp_recip / 2;
};
index_loops(energy);
energy_ *= para_.scaling();
}

void Vdwd2::cal_force()
{
ModuleBase::TITLE("Vdwd2", "force");
para_.initset(ucell_);
force_.clear();
force_.resize(ucell_.nat);

auto force = [&](double r,
double R0_sum,
double C6_product,
double r_sqr,
int it1,
int ia1,
const ModuleBase::Vector3<double> &tau1,
const ModuleBase::Vector3<double> &tau2) {
const double tmp_exp = exp(-para_.damping() * (r / R0_sum - 1));
const double tmp_factor = C6_product / pow(r_sqr, 3) / r / (1 + tmp_exp)
* (-6 / r + tmp_exp / (1 + tmp_exp) * para_.damping() / R0_sum);
force_[ucell_.itia2iat(it1, ia1)] += tmp_factor * (tau1 - tau2);
};

index_loops(force);
std::for_each(force_.begin(), force_.end(), [&](ModuleBase::Vector3<double> &f) {
f *= para_.scaling() / ucell_.lat0;
});
}

void Vdwd2::cal_stress()
{
ModuleBase::TITLE("Vdwd2", "stress");
para_.initset(ucell_);
stress_.Zero();

auto stress = [&](double r,
double R0_sum,
double C6_product,
double r_sqr,
int it1,
int ia1,
const ModuleBase::Vector3<double> &tau1,
const ModuleBase::Vector3<double> &tau2) {
const double tmp_exp = exp(-para_.damping() * (r / R0_sum - 1));
const double tmp_factor = C6_product / pow(r_sqr, 3) / r / (1 + tmp_exp)
* (-6 / r + tmp_exp / (1 + tmp_exp) * para_.damping() / R0_sum);
const ModuleBase::Vector3<double> dr = tau2 - tau1;
stress_ += tmp_factor / 2
* ModuleBase::Matrix3(dr.x * dr.x, dr.x * dr.y, dr.x * dr.z,
dr.y * dr.x, dr.y * dr.y, dr.y * dr.z,
dr.z * dr.x, dr.z * dr.y, dr.z * dr.z);
};

index_loops(stress);
stress_ *= para_.scaling() / ucell_.omega;
}

} // namespace vdw

0 comments on commit 64a23b8

Please sign in to comment.