From fa16019edbcde2734a0955ce7935ded62c0fbd34 Mon Sep 17 00:00:00 2001 From: xingliang Date: Thu, 27 Jan 2022 16:14:24 +0800 Subject: [PATCH 01/52] test: add comments and unittest for math_sphbes.h range: source/module_base --- source/module_base/math_sphbes.h | 39 ++++-- source/module_base/test/math_sphbes_test.cpp | 139 +++++++++++++++++++ 2 files changed, 168 insertions(+), 10 deletions(-) create mode 100644 source/module_base/test/math_sphbes_test.cpp diff --git a/source/module_base/math_sphbes.h b/source/module_base/math_sphbes.h index f152e52e659..60245e95d72 100644 --- a/source/module_base/math_sphbes.h +++ b/source/module_base/math_sphbes.h @@ -14,22 +14,41 @@ class Sphbes Sphbes(); ~Sphbes(); + /** + * @brief spherical bessel + * + * @param msh [in] number of grid points + * @param r [in] radial grid (1:msh) + * @param q [in] k_radial + * @param l [in] angular momentum + * @param jl [out] jl(1:msh) spherical bessel function + */ static void Spherical_Bessel ( - const int &msh, //number of grid points - const double *r,//radial grid - const double &q, // - const int &l, //angular momentum - double *jl //jl(1:msh) = j_l(q*r(i)),spherical bessel function + const int &msh, + const double *r, + const double &q, + const int &l, + double *jl ); + /** + * @brief spherical bessel + * + * @param msh [in] number of grid points + * @param r [in] radial grid (1:msh) + * @param q [in] k_radial + * @param l [in] angular momentum + * @param jl [out] jl(1:msh) spherical bessel function + * @param sjp [out] sjp[i] is assigned to be 1.0. i < msh. + */ static void Spherical_Bessel ( - const int &msh, //number of grid points - const double *r,//radial grid - const double &q, // - const int &l, //angular momentum - double *sj, //jl(1:msh) = j_l(q*r(i)),spherical bessel function + const int &msh, + const double *r, + const double &q, + const int &l, + double *sj, double *sjp ); diff --git a/source/module_base/test/math_sphbes_test.cpp b/source/module_base/test/math_sphbes_test.cpp new file mode 100644 index 00000000000..936e1a34245 --- /dev/null +++ b/source/module_base/test/math_sphbes_test.cpp @@ -0,0 +1,139 @@ +#include"../math_sphbes.h" +#include + +#ifdef __MPI +#include"mpi.h" +#endif + +#include"gtest/gtest.h" + +#define doublethreshold 1e-12 + + +/************************************************ +* unit test of class Integral +***********************************************/ + +/** + * Note: this unit test try to ensure the invariance + * of the spherical Bessel produced by class Sphbes, + * and the reference results are produced by ABACUS + * at 2022-1-27. + * + * Tested function: Spherical_Bessel. + * + */ + +double mean(const double* vect, const int totN) +{ + double meanv = 0.0; + for (int i=0; i< totN; ++i) {meanv += vect[i]/totN;} + return meanv; +} + +class Sphbes : public testing::Test +{ + protected: + + int msh = 700; + int l0 = 0; + int l1 = 1; + int l2 = 2; + int l3 = 3; + int l4 = 4; + int l5 = 5; + int l6 = 6; + int l7 = 7; + double q = 1.0; + double *r = new double[msh]; + double *jl = new double[msh]; + + void SetUp() + { + for(int i=0; i Date: Sat, 20 Nov 2021 15:13:48 +0800 Subject: [PATCH 02/52] add cal_grad_overlap in orb11 --- source/src_lcao/center2_orb-orb11.cpp | 91 +++++++++++++++++++++++++++ source/src_lcao/center2_orb-orb11.h | 8 ++- 2 files changed, 97 insertions(+), 2 deletions(-) diff --git a/source/src_lcao/center2_orb-orb11.cpp b/source/src_lcao/center2_orb-orb11.cpp index 5b24f12f132..31be64167eb 100644 --- a/source/src_lcao/center2_orb-orb11.cpp +++ b/source/src_lcao/center2_orb-orb11.cpp @@ -151,3 +151,94 @@ double Center2_Orb::Orb11::cal_overlap( return overlap; } + +ModuleBase::Vector3 Center2_Orb::Orb11::cal_grad_overlap( //caoyu add 2021-11-19 + const ModuleBase::Vector3 &RA, const ModuleBase::Vector3 &RB, + const int& mA, const int& mB) const +{ + const double tiny1 = 1e-12; // why 1e-12? + const double tiny2 = 1e-10; // why 1e-10? + + const ModuleBase::Vector3 delta_R = RB-RA; + const double distance_true = delta_R.norm(); + const double distance = (distance_true>=tiny1) ? distance_true : distance_true+tiny1; + const double RcutA = nA.getRcut(); + const double RcutB = nB.getRcut(); + if( distance > (RcutA + RcutB) ) return 0.0; + + const int LA = nA.getL(); + const int LB = nB.getL(); + + std::vector rly; + std::vector> tmp_grly; + std::vector> grly; + ModuleBase::Ylm::grad_rl_sph_harm( + LA + LB, + delta_R.x, delta_R.y, delta_R.z, + rly, tmp_grly); + for (const auto& tmp_ele : tmp_grly) + { + ModuleBase::Vector3 ele(tmp_ele[0], tmp_ele[1], tmp_ele[2]); + grly.push_back(ele); + } + + ModuleBase::Vector3 grad_overlap(0.0, 0.0, 0.0); + + for (const auto& tb_r : Table_r) + { + const int LAB = tb_r.first; + for( int mAB=0; mAB!=2*LAB+1; ++mAB ) + // const int mAB = mA + mB; + { + const double Gaunt_real_A_B_AB = + MGT.Gaunt_Coefficients ( + MGT.get_lm_index(LA,mA), + MGT.get_lm_index(LB,mB), + MGT.get_lm_index(LAB,mAB)); + if( 0==Gaunt_real_A_B_AB ) continue; + + const double ylm_solid = rly[ MGT.get_lm_index(LAB, mAB) ]; + if( 0==ylm_solid ) continue; + const double ylm_real = + (distance > tiny2) ? + ylm_solid / pow(distance,LAB) : + ylm_solid; + + const ModuleBase::Vector3 gylm_solid = grly[MGT.get_lm_index(LAB, mAB)]; + if( 0==gylm_solid.norm2() ) continue; //???? + const ModuleBase::Vector3 gylm_real = + (distance > tiny2) ? + gylm_solid / pow(distance,LAB) : + gylm_solid; + + const double i_exp = std::pow(-1.0, (LA - LB - LAB) / 2); + + const double Interp_Tlm = + (distance > tiny2) ? + ModuleBase::PolyInt::Polynomial_Interpolation( + ModuleBase::GlobalFunc::VECTOR_TO_PTR(tb_r.second), + MOT.get_rmesh(RcutA, RcutB), + MOT.dr, + distance ) : + tb_r.second.at(0); + + const double grad_Interp_Tlm = + (distance > tiny2) ? + ModuleBase::PolyInt::Polynomial_Interpolation( + ModuleBase::GlobalFunc::VECTOR_TO_PTR(this->Table_dr.at(LAB)), + MOT.get_rmesh(RcutA, RcutB), + MOT.dr, + distance) //Interp(Table_dr) + - Interp_Tlm * LAB / distance : + 0.0; + + grad_overlap += + i_exp // pow(2*PI,1.5) + * Gaunt_real_A_B_AB + * (Interp_Tlm * gylm_real + + grad_Interp_Tlm * ylm_real *delta_R / distance); + } + } + + return grad_overlap; +} \ No newline at end of file diff --git a/source/src_lcao/center2_orb-orb11.h b/source/src_lcao/center2_orb-orb11.h index d35b2efd22a..b10f58c044a 100644 --- a/source/src_lcao/center2_orb-orb11.h +++ b/source/src_lcao/center2_orb-orb11.h @@ -34,8 +34,12 @@ class Center2_Orb::Orb11 double cal_overlap( const ModuleBase::Vector3 &RA, const ModuleBase::Vector3 &RB, // unit: Bohr const int &mA, const int &mB) const; - - private: + + ModuleBase::Vector3 cal_grad_overlap( //caoyu add 2021-11-19 + const ModuleBase::Vector3 &RA, const ModuleBase::Vector3 &RB, // unit: Bohr + const int& mA, const int& mB) const; + +private: const Numerical_Orbital_Lm &nA; const Numerical_Orbital_Lm &nB; From 8914737181f5ca6408e11c47cdc7a45e6ee69a4d Mon Sep 17 00:00:00 2001 From: maki49 <1579492865@qq.com> Date: Thu, 13 Jan 2022 18:42:41 +0800 Subject: [PATCH 03/52] fix a bug of Center2Orb::Orb11::cal_grad_overlap --- source/src_lcao/center2_orb-orb11.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/source/src_lcao/center2_orb-orb11.cpp b/source/src_lcao/center2_orb-orb11.cpp index 31be64167eb..7873df7eac2 100644 --- a/source/src_lcao/center2_orb-orb11.cpp +++ b/source/src_lcao/center2_orb-orb11.cpp @@ -198,14 +198,12 @@ ModuleBase::Vector3 Center2_Orb::Orb11::cal_grad_overlap( //caoyu add if( 0==Gaunt_real_A_B_AB ) continue; const double ylm_solid = rly[ MGT.get_lm_index(LAB, mAB) ]; - if( 0==ylm_solid ) continue; const double ylm_real = (distance > tiny2) ? ylm_solid / pow(distance,LAB) : ylm_solid; const ModuleBase::Vector3 gylm_solid = grly[MGT.get_lm_index(LAB, mAB)]; - if( 0==gylm_solid.norm2() ) continue; //???? const ModuleBase::Vector3 gylm_real = (distance > tiny2) ? gylm_solid / pow(distance,LAB) : From 92d23f79a76375fa773a4841849a48f1470142f6 Mon Sep 17 00:00:00 2001 From: maki49 <1579492865@qq.com> Date: Thu, 13 Jan 2022 18:57:26 +0800 Subject: [PATCH 04/52] gtest: double-check 2-center-int results from 2 classes --- tests/module_orb/src/1_snap_equal_test.cpp | 73 ++++++++++ tests/module_orb/src/Makefile | 13 +- tests/module_orb/src/Makefile.Objects | 8 +- tests/module_orb/src/ORB_unittest.cpp | 160 ++++++++++++++++----- tests/module_orb/src/ORB_unittest.h | 48 ++++++- tests/module_orb/src/main.cpp | 41 +----- 6 files changed, 262 insertions(+), 81 deletions(-) create mode 100644 tests/module_orb/src/1_snap_equal_test.cpp diff --git a/tests/module_orb/src/1_snap_equal_test.cpp b/tests/module_orb/src/1_snap_equal_test.cpp new file mode 100644 index 00000000000..3831b08ea20 --- /dev/null +++ b/tests/module_orb/src/1_snap_equal_test.cpp @@ -0,0 +1,73 @@ +#include +#include"ORB_unittest.h" + +//Test whether the 2-center-int results +// and its derivative from two clases are equal. +// - ORB_gen_table::snap_psipsi(job=0) and Center2_Orb::Orb11::cal_overlap +// - ORB_gen_table::snap_psipsi(job=1) and Center2_Orb::Orb11::cal_grad_overlap +TEST_F(test_orb, equal_test) +{ + + this->set_center2orbs(); + //equal test + //orb + double olm_0[1] = { 0 }; + double olm_1[3] = { 0,0,0 }; + //center2orb + double clm_0 = 0; + ModuleBase::Vector3 clm_1; + + //test parameters + const double rmax = 5; //Ry + srand((unsigned)time(NULL)); + ModuleBase::Vector3 R1(0, 0, 0); + ModuleBase::Vector3 R2(randr(rmax), randr(rmax), randr(rmax)); + std::cout << "random R2=(" << R2.x << "," << R2.y << "," << R2.z << ")" << std::endl; + ModuleBase::Vector3 dR = ModuleBase::Vector3(0.001, 0.001, 0.001); + //4. calculate overlap and grad_overlap by both methods + int T1 = 0; + + for (int T2 = 0;T2 < ORB.get_ntype();++T2) + { + for (int L1 = 0;L1 < ORB.Phi[T1].getLmax();++L1) + { + for (int N1 = 0;N1 < ORB.Phi[T1].getNchi(L1);++N1) + { + for (int L2 = 0;L2 < ORB.Phi[T2].getLmax();++L2) + { + for (int N2 = 0;N2 < ORB.Phi[T2].getNchi(L2);++N2) + { + for (int m1 = 0;m1 < 2 * L1 + 1;++m1) + { + for (int m2 = 0;m2 < 2 * L2 + 1;++m2) + { + OGT.snap_psipsi( + ORB, olm_0, 0, 'S', + R1, T1, L1, m1, N1, + R2, T2, L2, m2, N2, + 1, NULL); + OGT.snap_psipsi( + ORB, olm_1, 1, 'S', + R1, T1, L1, m1, N1, + R2, T2, L2, m2, N2, + 1, NULL); + //std::cout << this->mock_center2_orb11[T1][T2][L1][N1][L2][N2]->cal_overlap(R1, R2, m1, m2); + clm_0 = + test_center2_orb11[T1][T2][L1][N1][L2][N2]->cal_overlap(R1, R2, m1, m2); + clm_1 = + test_center2_orb11[T1][T2][L1][N1][L2][N2]->cal_grad_overlap(R1, R2, m1, m2); + EXPECT_NEAR(olm_0[0], clm_0, 1e-10); + EXPECT_NEAR(olm_1[0], clm_1.x, 1e-10); + EXPECT_NEAR(olm_1[1], clm_1.y, 1e-10); + EXPECT_NEAR(olm_1[2], clm_1.z, 1e-10); + ModuleBase::GlobalFunc::ZEROS(olm_1, 3); + } + } + } + + } + } + } + } +} + diff --git a/tests/module_orb/src/Makefile b/tests/module_orb/src/Makefile index 5887dd5b5cb..5fb47a29ffc 100644 --- a/tests/module_orb/src/Makefile +++ b/tests/module_orb/src/Makefile @@ -5,8 +5,10 @@ #========================== CPLUSPLUS = icpc CPLUSPLUS_MPI = icpc -FFTW_DIR = /home/liwenfei/codes/FFTW3 +FFTW_DIR = /home/fortneu49/soft/fftw-3.3.8 OBJ_DIR = orb_obj +GTEST_DIR = /usr/local/lib +GMOCK_DIR = /usr/local/lib NP = 1 #========================== @@ -20,7 +22,7 @@ FFTW_LIB = -L${FFTW_LIB_DIR} -lfftw3 -Wl,-rpath=${FFTW_LIB_DIR} #========================== # LIBS and INCLUDES #========================== -LIBS = -lifcore -lm -lpthread ${FFTW_LIB} +LIBS = -lifcore -lm -lpthread ${FFTW_LIB} ${GTEST_DIR}/libgtest.a ${GMOCK_DIR}/libgmock.a #========================== # OPTIMIZE OPTIONS @@ -29,13 +31,14 @@ INCLUDES = -I. -Icommands -I${FFTW_INCLUDE_DIR} # -pedantic turns off more extensions and generates more warnings # -xHost generates instructions for the highest instruction set available on the compilation host processor -OPTS = ${INCLUDES} -Ofast -std=c++11 -simd -march=native -xHost -m64 -qopenmp -Werror -Wall -pedantic -g +OPTS = ${INCLUDES} -Ofast -std=c++14 -simd -march=native -m64 -qopenmp -Werror -Wall -pedantic -g include Makefile.Objects VPATH=../../../source/module_base\ :../../../source/src_global\ :../../../source/module_orbital\ +:../../../source/src_lcao\ :./\ #========================== @@ -46,6 +49,8 @@ HONG= -DMETIS -DMKL_ILP64 -D__NORMAL -D__ORBITAL ${HONG_FFTW} FP_OBJS_0=main.o\ $(OBJS_BASE)\ $(OBJS_ORBITAL)\ +$(OBJS_CTO)\ +$(OBJS_TEST)\ FP_OBJS=$(patsubst %.o, ${OBJ_DIR}/%.o, ${FP_OBJS_0}) @@ -60,7 +65,7 @@ init : @ if [ ! -d $(OBJ_DIR) ]; then mkdir $(OBJ_DIR); fi @ if [ ! -d $(OBJ_DIR)/README ]; then echo "This directory contains all of the .o files" > $(OBJ_DIR)/README; fi -serial : ${FP_OBJS} +serial : ${FP_OBJS} ${CPLUSPLUS} ${OPTS} $(FP_OBJS) ${LIBS} -o ${VERSION}.x #========================== diff --git a/tests/module_orb/src/Makefile.Objects b/tests/module_orb/src/Makefile.Objects index 81b3dc995f8..0ac16871004 100644 --- a/tests/module_orb/src/Makefile.Objects +++ b/tests/module_orb/src/Makefile.Objects @@ -39,4 +39,10 @@ ORB_table_beta.o\ ORB_table_phi.o\ ORB_table_alpha.o\ ORB_gen_tables.o\ -ORB_unittest.o\ + +OBJS_CTO=center2_orb-orb11.o\ + +OBJS_TEST=ORB_unittest.o\ +1_snap_equal_test.o\ +#2_snap_finite_diff_test.o\ +#3_center2_finite_diff_test.o\ \ No newline at end of file diff --git a/tests/module_orb/src/ORB_unittest.cpp b/tests/module_orb/src/ORB_unittest.cpp index e5e008b8bf2..b0a29ec6104 100644 --- a/tests/module_orb/src/ORB_unittest.cpp +++ b/tests/module_orb/src/ORB_unittest.cpp @@ -1,13 +1,36 @@ #include "ORB_unittest.h" -#include -#include -#include -test_orb::test_orb() -{} +void test_orb::SetUp() +{ + //test constructor + /*Center2_Orb::Orb11 testcto = Center2_Orb::Orb11( + ORB.Phi[0].PhiLN(0, 0), + ORB.Phi[0].PhiLN(0, 0), + OGT.MOT, Center2_MGT);*/ + // 1. setup orbitals + this->ofs_running.open("log.txt"); + this->count_ntype(); + this->set_files(); + this->set_ekcut(); + + + //2. setup 2-center-integral tables by basic methods + // not including center2orb, it will be set up when needed + // in some test cases. + this->set_orbs(); + //this->set_center2orbs(); + +} -test_orb::~test_orb() -{} +void test_orb::TearDown() +{ + int* nproj = new int[ORB.get_ntype()]; + for (int i = 0;i < ORB.get_ntype();++i) + nproj[i] = 0; + ooo.clear_after_ions(OGT, ORB, 0, nproj); + delete[] nproj; + return; +} void test_orb::set_ekcut() { @@ -17,7 +40,7 @@ void test_orb::set_ekcut() lcao_ecut=0.0; std::ifstream in_ao; - for(int it=0;it> ofile; @@ -103,7 +137,7 @@ void test_orb::count_ntype() ModuleBase::GlobalFunc::SCAN_BEGIN(ifs,"ATOMIC_SPECIES"); - ntype = 0; + ntype_read = 0; std::string x; ifs.rdstate(); @@ -120,12 +154,70 @@ void test_orb::count_ntype() if(x=="LATTICE_CONSTANT" || x=="NUMERICAL_ORBITAL" || x=="LATTICE_VECTORS" || x=="ATOMIC_POSITIONS") break; std::string tmpid=x.substr(0,1); - if(!x.empty() && tmpid!="#") ntype++; + if(!x.empty() && tmpid!="#") ntype_read++; } - std::cout << "ntype : "<< ntype << std::endl; ifs.close(); return; } + +void test_orb::set_center2orbs() +{ + //1. setup Gaunt coeffs + Center2_MGT.init_Gaunt_CH( lmax ); + Center2_MGT.init_Gaunt(lmax); + //2. setup tables + + for (int TA = 0; TA < ORB.get_ntype(); TA++) + { + for (int TB = 0; TB < ORB.get_ntype(); TB++) + { + for (int LA=0; LA <= ORB.Phi[TA].getLmax() ; LA++) + { + for (int NA = 0; NA < ORB.Phi[TA].getNchi(LA); ++NA) + { + for (int LB = 0; LB <= ORB.Phi[TB].getLmax(); ++LB) + { + for (int NB = 0; NB < ORB.Phi[TB].getNchi(LB); ++NB) + { + this->set_single_c2o(TA, TB, LA, NA, LB, NB); + // test_center2_orb11[TA][TB][LA][NA][LB].insert( + // make_pair(NB, MockCenter2Orb11(ORB.Phi[TA].PhiLN(LA, NA), + // ORB.Phi[TB].PhiLN(LB, NB), OGT.MOT, Center2_MGT))); + } + } + } + } + } + } + + for (auto& co1 : this->test_center2_orb11) + for( auto &co2 : co1.second ) + for( auto &co3 : co2.second ) + for( auto &co4 : co3.second ) + for( auto &co5 : co4.second ) + for( auto &co6 : co5.second ) + co6.second->init_radial_table(); +} +template +void test_orb::set_single_c2o(int TA, int TB, int LA, int NA, int LB, int NB) +{ + this->test_center2_orb11[TA][TB][LA][NA][LB].insert( + make_pair(NB, std::make_unique(ORB.Phi[TA].PhiLN(LA, NA), + ORB.Phi[TB].PhiLN(LB, NB), OGT.MOT, Center2_MGT))); +} +double test_orb::randr(double Rmax) +{ + return double(rand()) / double(RAND_MAX) * Rmax; +} + +/* +void test_orb::test() { + ModuleBase::Vector3 R1(0, 0, 0); + ModuleBase::Vector3 R2(randr(50), randr(50), randr(50)); + std::cout << "random R2=(" << R2.x << "," << R2.y << "," << R2.z << ")" << std::endl; + ModuleBase::Vector3 dR = ModuleBase::Vector3(0.001, 0.001, 0.001); + std::cout << this->test_center2_orb11[0][0][0][0][0][0].cal_overlap(R1, R2, 0, 0); +}*/ \ No newline at end of file diff --git a/tests/module_orb/src/ORB_unittest.h b/tests/module_orb/src/ORB_unittest.h index e9a2e59d9fb..d2935b2868f 100644 --- a/tests/module_orb/src/ORB_unittest.h +++ b/tests/module_orb/src/ORB_unittest.h @@ -1,31 +1,64 @@ +#ifndef _ORBTEST_ +#define _ORBTEST_ + +#include "gtest/gtest.h" #include "../../../source/module_orbital/ORB_control.h" #include "../../../source/module_base/global_function.h" +#include "../../../source/src_lcao/center2_orb-orb11.h" +//#include "mock_center2.h" #include #include #include + +#include +#include +#include + using namespace std; -class test_orb +class test_orb : public testing::Test { +protected: + void SetUp() override; + void TearDown() override; public: - test_orb(); - ~test_orb(); - LCAO_Orbitals ORB; ORB_gen_tables OGT; + ORB_gaunt_table Center2_MGT; //gaunt table used in center2orb ORB_control ooo; - std::ofstream ofs_running; + std::map < size_t, + std::map>>>>>> test_center2_orb11; +/* + std::map < size_t, + std::map>>>>>> mock_center2_orb11; +*/ void count_ntype(); //from STRU, count types of elements void set_files(); //from STRU, read names of LCAO files void set_ekcut(); //from LCAO files, read and set ekcut - void set_orbs(const double &lat0_in); //interface to Read_PAO + void set_orbs(); //interface to Read_PAO + void set_center2orbs(); //interface to Center2orb + template + void set_single_c2o(int TA, int TB, int LA, int NA, int LB, int NB); + double randr(double Rmax); + void gen_table_center2(); + bool force_flag = 0; int my_rank = 0; - int ntype; + int ntype_read; double lcao_ecut = 0; // (Ry) double lcao_dk = 0.01; @@ -38,3 +71,4 @@ class test_orb int lmax=1; double lat0 = 1.0; }; +#endif diff --git a/tests/module_orb/src/main.cpp b/tests/module_orb/src/main.cpp index e05342eb71c..b99c1456e3b 100644 --- a/tests/module_orb/src/main.cpp +++ b/tests/module_orb/src/main.cpp @@ -1,37 +1,8 @@ -//#include "timer.h" -#include -#include -#include "ORB_unittest.h" -#include "../../../source/module_base/global_variable.h" -#include "../../../source/module_base/global_file.h" - -void calculate(); - -int main(int argc, char **argv) +#include "gtest/gtest.h" +int main(int argc, char** argv) { - - std::cout << "Hello, this is the ORB module of ABACUS." << std::endl; - - calculate(); - - return 0; -} - -void calculate() -{ - GlobalV::BASIS_TYPE = "lcao"; - - test_orb test; - - test.ofs_running.open("log.txt"); - test.count_ntype(); - test.set_files(); - test.set_ekcut(); - test.set_orbs(test.lat0); - - std::cout << "--------------------" << std::endl; - std::cout << " Have a great day! " << std::endl; - std::cout << "--------------------" << std::endl; - - return; +// test if the result of Center_2_Orb::Orb11:cal_overlap +// is equal to the result of ORB_gen_table::snap_psipsi + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); } From 332932888f77f716616007a43fd21e4bc5984fe7 Mon Sep 17 00:00:00 2001 From: maki49 <1579492865@qq.com> Date: Fri, 14 Jan 2022 11:45:07 +0800 Subject: [PATCH 05/52] move to module_orbital/test and add CMakeLists.txt --- source/module_orbital/CMakeLists.txt | 5 +++++ .../module_orbital/test}/1_snap_equal_test.cpp | 0 source/module_orbital/test/CMakeLists.txt | 4 ++++ .../module_orbital/test}/GaAs/As_dojo.orb | 0 .../module_orbital/test}/GaAs/Ga_dojo.orb | 0 .../module_orbital/test}/GaAs/README | 0 .../module_orbital/test}/GaAs/STRU | 0 .../src => source/module_orbital/test}/Makefile | 0 .../module_orbital/test}/Makefile.Objects | 0 .../module_orbital/test}/ORB_unittest.cpp | 13 +++++++------ .../module_orbital/test}/ORB_unittest.h | 1 + source/module_orbital/{2_UnitTests => test}/README | 0 .../src => source/module_orbital/test}/main.cpp | 0 13 files changed, 17 insertions(+), 6 deletions(-) rename {tests/module_orb/src => source/module_orbital/test}/1_snap_equal_test.cpp (100%) create mode 100644 source/module_orbital/test/CMakeLists.txt rename {tests/module_orb => source/module_orbital/test}/GaAs/As_dojo.orb (100%) rename {tests/module_orb => source/module_orbital/test}/GaAs/Ga_dojo.orb (100%) rename {tests/module_orb => source/module_orbital/test}/GaAs/README (100%) rename {tests/module_orb => source/module_orbital/test}/GaAs/STRU (100%) rename {tests/module_orb/src => source/module_orbital/test}/Makefile (100%) rename {tests/module_orb/src => source/module_orbital/test}/Makefile.Objects (100%) rename {tests/module_orb/src => source/module_orbital/test}/ORB_unittest.cpp (93%) rename {tests/module_orb/src => source/module_orbital/test}/ORB_unittest.h (96%) rename source/module_orbital/{2_UnitTests => test}/README (100%) rename {tests/module_orb/src => source/module_orbital/test}/main.cpp (100%) diff --git a/source/module_orbital/CMakeLists.txt b/source/module_orbital/CMakeLists.txt index 3a753030426..f7fa6e8d385 100644 --- a/source/module_orbital/CMakeLists.txt +++ b/source/module_orbital/CMakeLists.txt @@ -13,3 +13,8 @@ add_library( ORB_table_beta.cpp ORB_table_phi.cpp ) + +IF (BUILD_TESTING) + set(CMAKE_CXX_STANDARD 14) + add_subdirectory(test) +endif() \ No newline at end of file diff --git a/tests/module_orb/src/1_snap_equal_test.cpp b/source/module_orbital/test/1_snap_equal_test.cpp similarity index 100% rename from tests/module_orb/src/1_snap_equal_test.cpp rename to source/module_orbital/test/1_snap_equal_test.cpp diff --git a/source/module_orbital/test/CMakeLists.txt b/source/module_orbital/test/CMakeLists.txt new file mode 100644 index 00000000000..a17d16a8d36 --- /dev/null +++ b/source/module_orbital/test/CMakeLists.txt @@ -0,0 +1,4 @@ +AddTest( + TARGET orbital_equal_test # this is the executable file name of the test + SOURCES 1_snap_equal_test.cpp ORB_unittest.cpp +) \ No newline at end of file diff --git a/tests/module_orb/GaAs/As_dojo.orb b/source/module_orbital/test/GaAs/As_dojo.orb similarity index 100% rename from tests/module_orb/GaAs/As_dojo.orb rename to source/module_orbital/test/GaAs/As_dojo.orb diff --git a/tests/module_orb/GaAs/Ga_dojo.orb b/source/module_orbital/test/GaAs/Ga_dojo.orb similarity index 100% rename from tests/module_orb/GaAs/Ga_dojo.orb rename to source/module_orbital/test/GaAs/Ga_dojo.orb diff --git a/tests/module_orb/GaAs/README b/source/module_orbital/test/GaAs/README similarity index 100% rename from tests/module_orb/GaAs/README rename to source/module_orbital/test/GaAs/README diff --git a/tests/module_orb/GaAs/STRU b/source/module_orbital/test/GaAs/STRU similarity index 100% rename from tests/module_orb/GaAs/STRU rename to source/module_orbital/test/GaAs/STRU diff --git a/tests/module_orb/src/Makefile b/source/module_orbital/test/Makefile similarity index 100% rename from tests/module_orb/src/Makefile rename to source/module_orbital/test/Makefile diff --git a/tests/module_orb/src/Makefile.Objects b/source/module_orbital/test/Makefile.Objects similarity index 100% rename from tests/module_orb/src/Makefile.Objects rename to source/module_orbital/test/Makefile.Objects diff --git a/tests/module_orb/src/ORB_unittest.cpp b/source/module_orbital/test/ORB_unittest.cpp similarity index 93% rename from tests/module_orb/src/ORB_unittest.cpp rename to source/module_orbital/test/ORB_unittest.cpp index b0a29ec6104..52d50c9f1ed 100644 --- a/tests/module_orb/src/ORB_unittest.cpp +++ b/source/module_orbital/test/ORB_unittest.cpp @@ -8,8 +8,8 @@ void test_orb::SetUp() ORB.Phi[0].PhiLN(0, 0), OGT.MOT, Center2_MGT);*/ // 1. setup orbitals - this->ofs_running.open("log.txt"); - this->count_ntype(); + this->ofs_running.open("log.txt"); + this->count_ntype(); this->set_files(); this->set_ekcut(); @@ -44,7 +44,7 @@ void test_orb::set_ekcut() { double ek_current; - in_ao.open(ORB.orbital_file[it].c_str()); + in_ao.open((this->case_dir+ORB.orbital_file[it].c_str())); if(!in_ao) { std::cout << "error : cannot find LCAO file : " << ORB.orbital_file[it] << std::endl; @@ -107,7 +107,7 @@ void test_orb::set_orbs() void test_orb::set_files() { std::cout << "read names of atomic basis set files" << std::endl; - std::ifstream ifs("STRU",std::ios::in); + std::ifstream ifs((this->case_dir + "STRU"),std::ios::in); ModuleBase::GlobalFunc::SCAN_BEGIN(ifs,"NUMERICAL_ORBITAL"); ORB.read_in_flag = true; @@ -127,7 +127,8 @@ void test_orb::set_files() void test_orb::count_ntype() { std::cout << "count number of atom types" << std::endl; - std::ifstream ifs("STRU",std::ios::in); + std::cout << this->case_dir +"STRU" << std::endl; + std::ifstream ifs( (this->case_dir+ "STRU"), std::ios::in); if (!ifs) { @@ -156,7 +157,7 @@ void test_orb::count_ntype() std::string tmpid=x.substr(0,1); if(!x.empty() && tmpid!="#") ntype_read++; } - + std::cout <<"ntype="<< ntype_read << std::endl; ifs.close(); return; diff --git a/tests/module_orb/src/ORB_unittest.h b/source/module_orbital/test/ORB_unittest.h similarity index 96% rename from tests/module_orb/src/ORB_unittest.h rename to source/module_orbital/test/ORB_unittest.h index d2935b2868f..cf4843dbb16 100644 --- a/tests/module_orb/src/ORB_unittest.h +++ b/source/module_orbital/test/ORB_unittest.h @@ -70,5 +70,6 @@ class test_orb : public testing::Test int lmax=1; double lat0 = 1.0; + string case_dir = "../../../../source/module_orbital/test/GaAs/"; }; #endif diff --git a/source/module_orbital/2_UnitTests/README b/source/module_orbital/test/README similarity index 100% rename from source/module_orbital/2_UnitTests/README rename to source/module_orbital/test/README diff --git a/tests/module_orb/src/main.cpp b/source/module_orbital/test/main.cpp similarity index 100% rename from tests/module_orb/src/main.cpp rename to source/module_orbital/test/main.cpp From ffcdfad4c0b02810e0a8dc6936cfccb685222dbe Mon Sep 17 00:00:00 2001 From: maki49 <1579492865@qq.com> Date: Fri, 14 Jan 2022 12:14:17 +0800 Subject: [PATCH 06/52] fix: orbital file not found --- source/module_orbital/test/ORB_unittest.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/module_orbital/test/ORB_unittest.cpp b/source/module_orbital/test/ORB_unittest.cpp index 52d50c9f1ed..86d506ca5f6 100644 --- a/source/module_orbital/test/ORB_unittest.cpp +++ b/source/module_orbital/test/ORB_unittest.cpp @@ -49,7 +49,7 @@ void test_orb::set_ekcut() { std::cout << "error : cannot find LCAO file : " << ORB.orbital_file[it] << std::endl; } - + ORB.orbital_file[it] = this->case_dir + ORB.orbital_file[it].c_str(); string word; while (in_ao.good()) { From 99071d694da7cbcb9eaf65f5caf3183deaba428f Mon Sep 17 00:00:00 2001 From: maki49 <1579492865@qq.com> Date: Wed, 19 Jan 2022 20:37:45 +0800 Subject: [PATCH 07/52] add MPI_init --- source/module_orbital/test/main.cpp | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/source/module_orbital/test/main.cpp b/source/module_orbital/test/main.cpp index b99c1456e3b..76d17ac22db 100644 --- a/source/module_orbital/test/main.cpp +++ b/source/module_orbital/test/main.cpp @@ -1,8 +1,14 @@ #include "gtest/gtest.h" int main(int argc, char** argv) { -// test if the result of Center_2_Orb::Orb11:cal_overlap -// is equal to the result of ORB_gen_table::snap_psipsi - testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); +#ifdef __MPI + MPI_Init(&argc,&argv); +#endif + + testing::InitGoogleTest(&argc, argv); + int result = RUN_ALL_TESTS(); + +#ifdef __MPI + MPI_Finalize(); +#endif } From f3b7fd483024fd62ffb610d3485494f3b19ad347 Mon Sep 17 00:00:00 2001 From: maki49 <1579492865@qq.com> Date: Sat, 22 Jan 2022 16:29:02 +0800 Subject: [PATCH 08/52] move main func to 1_snap_equal_test.cpp --- source/module_orbital/test/1_snap_equal_test.cpp | 13 +++++++++++++ source/module_orbital/test/main.cpp | 14 -------------- 2 files changed, 13 insertions(+), 14 deletions(-) delete mode 100644 source/module_orbital/test/main.cpp diff --git a/source/module_orbital/test/1_snap_equal_test.cpp b/source/module_orbital/test/1_snap_equal_test.cpp index 3831b08ea20..e506034c3f5 100644 --- a/source/module_orbital/test/1_snap_equal_test.cpp +++ b/source/module_orbital/test/1_snap_equal_test.cpp @@ -71,3 +71,16 @@ TEST_F(test_orb, equal_test) } } +int main(int argc, char** argv) +{ +#ifdef __MPI + MPI_Init(&argc,&argv); +#endif + + testing::InitGoogleTest(&argc, argv); + int result = RUN_ALL_TESTS(); + +#ifdef __MPI + MPI_Finalize(); +#endif +} diff --git a/source/module_orbital/test/main.cpp b/source/module_orbital/test/main.cpp deleted file mode 100644 index 76d17ac22db..00000000000 --- a/source/module_orbital/test/main.cpp +++ /dev/null @@ -1,14 +0,0 @@ -#include "gtest/gtest.h" -int main(int argc, char** argv) -{ -#ifdef __MPI - MPI_Init(&argc,&argv); -#endif - - testing::InitGoogleTest(&argc, argv); - int result = RUN_ALL_TESTS(); - -#ifdef __MPI - MPI_Finalize(); -#endif -} From 5e99edea74c40b1675b8af42096728437f6af327 Mon Sep 17 00:00:00 2001 From: maki49 <1579492865@qq.com> Date: Fri, 28 Jan 2022 20:37:49 +0800 Subject: [PATCH 09/52] modify according to the review --- source/module_orbital/CMakeLists.txt | 1 + source/module_orbital/test/1_snap_equal_test.cpp | 1 + source/src_lcao/center2_orb-orb11.cpp | 9 +++++---- source/src_lcao/center2_orb-orb11.h | 2 +- 4 files changed, 8 insertions(+), 5 deletions(-) diff --git a/source/module_orbital/CMakeLists.txt b/source/module_orbital/CMakeLists.txt index f7fa6e8d385..bd630fa4f0e 100644 --- a/source/module_orbital/CMakeLists.txt +++ b/source/module_orbital/CMakeLists.txt @@ -14,6 +14,7 @@ add_library( ORB_table_phi.cpp ) +set(CMAKE_CXX_STANDARD_REQUIRED ON) IF (BUILD_TESTING) set(CMAKE_CXX_STANDARD 14) add_subdirectory(test) diff --git a/source/module_orbital/test/1_snap_equal_test.cpp b/source/module_orbital/test/1_snap_equal_test.cpp index e506034c3f5..3d1bfda9f34 100644 --- a/source/module_orbital/test/1_snap_equal_test.cpp +++ b/source/module_orbital/test/1_snap_equal_test.cpp @@ -83,4 +83,5 @@ int main(int argc, char** argv) #ifdef __MPI MPI_Finalize(); #endif + return result; } diff --git a/source/src_lcao/center2_orb-orb11.cpp b/source/src_lcao/center2_orb-orb11.cpp index 7873df7eac2..34311cbdd53 100644 --- a/source/src_lcao/center2_orb-orb11.cpp +++ b/source/src_lcao/center2_orb-orb11.cpp @@ -154,17 +154,18 @@ double Center2_Orb::Orb11::cal_overlap( ModuleBase::Vector3 Center2_Orb::Orb11::cal_grad_overlap( //caoyu add 2021-11-19 const ModuleBase::Vector3 &RA, const ModuleBase::Vector3 &RB, - const int& mA, const int& mB) const + const int& mA, const int& mB) const { - const double tiny1 = 1e-12; // why 1e-12? - const double tiny2 = 1e-10; // why 1e-10? + const double tiny1 = 1e-12; // same as `cal_overlap` + const double tiny2 = 1e-10; //same as `cal_overlap` const ModuleBase::Vector3 delta_R = RB-RA; const double distance_true = delta_R.norm(); const double distance = (distance_true>=tiny1) ? distance_true : distance_true+tiny1; const double RcutA = nA.getRcut(); const double RcutB = nB.getRcut(); - if( distance > (RcutA + RcutB) ) return 0.0; + if( distance > (RcutA + RcutB) ) + return ModuleBase::Vector3(0.0, 0.0, 0.0); const int LA = nA.getL(); const int LB = nB.getL(); diff --git a/source/src_lcao/center2_orb-orb11.h b/source/src_lcao/center2_orb-orb11.h index b10f58c044a..6266ee82772 100644 --- a/source/src_lcao/center2_orb-orb11.h +++ b/source/src_lcao/center2_orb-orb11.h @@ -37,7 +37,7 @@ class Center2_Orb::Orb11 ModuleBase::Vector3 cal_grad_overlap( //caoyu add 2021-11-19 const ModuleBase::Vector3 &RA, const ModuleBase::Vector3 &RB, // unit: Bohr - const int& mA, const int& mB) const; + const int& mA, const int& mB) const; private: From 7a535a57f711bf1dd2fd90743a24d44af02cf430 Mon Sep 17 00:00:00 2001 From: maki49 <1579492865@qq.com> Date: Tue, 8 Feb 2022 21:44:09 +0800 Subject: [PATCH 10/52] modify CMakeLists.txt according to #676 --- .../module_orbital/test/1_snap_equal_test.cpp | 17 +------- source/module_orbital/test/CMakeLists.txt | 42 ++++++++++++++++++- 2 files changed, 42 insertions(+), 17 deletions(-) diff --git a/source/module_orbital/test/1_snap_equal_test.cpp b/source/module_orbital/test/1_snap_equal_test.cpp index 3d1bfda9f34..a3fe5423cb7 100644 --- a/source/module_orbital/test/1_snap_equal_test.cpp +++ b/source/module_orbital/test/1_snap_equal_test.cpp @@ -69,19 +69,4 @@ TEST_F(test_orb, equal_test) } } } -} - -int main(int argc, char** argv) -{ -#ifdef __MPI - MPI_Init(&argc,&argv); -#endif - - testing::InitGoogleTest(&argc, argv); - int result = RUN_ALL_TESTS(); - -#ifdef __MPI - MPI_Finalize(); -#endif - return result; -} +} \ No newline at end of file diff --git a/source/module_orbital/test/CMakeLists.txt b/source/module_orbital/test/CMakeLists.txt index a17d16a8d36..072a4579623 100644 --- a/source/module_orbital/test/CMakeLists.txt +++ b/source/module_orbital/test/CMakeLists.txt @@ -1,4 +1,44 @@ +remove_definitions(-D__MPI) + +list(APPEND depend_files + ../../module_base/math_integral.cpp + ../../module_base/math_sphbes.cpp + ../../module_base/math_polyint.cpp + ../../module_base/math_ylmreal.cpp + ../../module_base/ylm.cpp + ../../module_base/memory.cpp + ../../module_base/complexarray.cpp + ../../module_base/complexmatrix.cpp + ../../module_base/matrix.cpp + ../../module_base/realarray.cpp + ../../module_base/intarray.cpp + ../../module_base/sph_bessel.cpp + ../../module_base/sph_bessel_recursive-d1.cpp + ../../module_base/sph_bessel_recursive-d2.cpp + ../../module_base/tool_title.cpp + ../../module_base/tool_quit.cpp + ../../module_base/tool_check.cpp + ../../module_base/timer.cpp + ../../module_base/mathzone_add1.cpp + ../../module_base/global_variable.cpp + ../../module_base/global_function.cpp + ../../module_base/global_file.cpp + ../ORB_control.cpp + ../ORB_read.cpp + ../ORB_atomic.cpp + ../ORB_atomic_lm.cpp + ../ORB_nonlocal.cpp + ../ORB_nonlocal_lm.cpp + ../ORB_gaunt_table.cpp + ../ORB_table_beta.cpp + ../ORB_table_phi.cpp + ../ORB_table_alpha.cpp + ../ORB_gen_tables.cpp + ../../src_lcao/center2_orb-orb11.cpp + ) AddTest( - TARGET orbital_equal_test # this is the executable file name of the test + TARGET orbital_equal_test + LIBS ${math_libs} SOURCES 1_snap_equal_test.cpp ORB_unittest.cpp + ${depend_files} ) \ No newline at end of file From 238828a4eb3fd1d0484e95964e033fe720bb5c75 Mon Sep 17 00:00:00 2001 From: maki49 <1579492865@qq.com> Date: Tue, 8 Feb 2022 21:45:25 +0800 Subject: [PATCH 11/52] fix compile error caused by not linking lib in module_md/test --- source/module_md/test/CMakeLists.txt | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/source/module_md/test/CMakeLists.txt b/source/module_md/test/CMakeLists.txt index 3458a0e6e79..4304377ddc2 100644 --- a/source/module_md/test/CMakeLists.txt +++ b/source/module_md/test/CMakeLists.txt @@ -42,18 +42,21 @@ list(APPEND depend_files AddTest( TARGET md_LJ_pot + LIBS ${math_libs} SOURCES LJ_pot_test.cpp ${depend_files} ) AddTest( TARGET md_func + LIBS ${math_libs} SOURCES MD_func_test.cpp ${depend_files} ) AddTest( TARGET md_fire + LIBS ${math_libs} SOURCES FIRE_test.cpp ../verlet.cpp ../FIRE.cpp @@ -62,6 +65,7 @@ AddTest( AddTest( TARGET md_nve + LIBS ${math_libs} SOURCES NVE_test.cpp ../verlet.cpp ../NVE.cpp @@ -70,6 +74,7 @@ AddTest( AddTest( TARGET md_nvt_ads + LIBS ${math_libs} SOURCES NVT_ADS_test.cpp ../verlet.cpp ../NVT_ADS.cpp @@ -78,6 +83,7 @@ AddTest( AddTest( TARGET md_nvt_nhc + LIBS ${math_libs} SOURCES NVT_NHC_test.cpp ../verlet.cpp ../NVT_NHC.cpp @@ -86,6 +92,7 @@ AddTest( AddTest( TARGET md_msst + LIBS ${math_libs} SOURCES MSST_test.cpp ../verlet.cpp ../MSST.cpp @@ -94,6 +101,7 @@ AddTest( AddTest( TARGET md_lgv + LIBS ${math_libs} SOURCES Langevin_test.cpp ../verlet.cpp ../Langevin.cpp From e31807d0ac9dd44dce9e973ebb5f326bb1c52954 Mon Sep 17 00:00:00 2001 From: maki49 <1579492865@qq.com> Date: Wed, 9 Feb 2022 14:45:26 +0800 Subject: [PATCH 12/52] fix 'file not found' --- source/module_orbital/test/CMakeLists.txt | 4 +++- source/module_orbital/test/ORB_unittest.h | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/source/module_orbital/test/CMakeLists.txt b/source/module_orbital/test/CMakeLists.txt index 072a4579623..9c780a75266 100644 --- a/source/module_orbital/test/CMakeLists.txt +++ b/source/module_orbital/test/CMakeLists.txt @@ -41,4 +41,6 @@ AddTest( LIBS ${math_libs} SOURCES 1_snap_equal_test.cpp ORB_unittest.cpp ${depend_files} -) \ No newline at end of file +) +install(DIRECTORY GaAs DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/../../../tests) +install(DIRECTORY GaAs DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) \ No newline at end of file diff --git a/source/module_orbital/test/ORB_unittest.h b/source/module_orbital/test/ORB_unittest.h index cf4843dbb16..d817cf07e84 100644 --- a/source/module_orbital/test/ORB_unittest.h +++ b/source/module_orbital/test/ORB_unittest.h @@ -70,6 +70,6 @@ class test_orb : public testing::Test int lmax=1; double lat0 = 1.0; - string case_dir = "../../../../source/module_orbital/test/GaAs/"; + string case_dir = "./GaAs/"; }; #endif From 7680e5e2a7c5147a40d4dffbd541004efb33b25c Mon Sep 17 00:00:00 2001 From: xingliang Date: Fri, 11 Feb 2022 10:08:58 +0800 Subject: [PATCH 13/52] test: add the unittest of module_base/matrix.h range: source/module_base/test/matrix_test.cpp --- source/module_base/test/CMakeLists.txt | 5 + source/module_base/test/matrix_test.cpp | 334 ++++++++++++++++++++++++ 2 files changed, 339 insertions(+) create mode 100644 source/module_base/test/matrix_test.cpp diff --git a/source/module_base/test/CMakeLists.txt b/source/module_base/test/CMakeLists.txt index 1241410e700..b7e9a9b7281 100644 --- a/source/module_base/test/CMakeLists.txt +++ b/source/module_base/test/CMakeLists.txt @@ -38,3 +38,8 @@ AddTest( LIBS ${math_libs} SOURCES complexmatrix_test.cpp ../complexmatrix.cpp ../matrix.cpp ) +AddTest( + TARGET base_matrix + LIBS ${math_libs} + SOURCES matrix_test.cpp ../matrix.cpp +) \ No newline at end of file diff --git a/source/module_base/test/matrix_test.cpp b/source/module_base/test/matrix_test.cpp new file mode 100644 index 00000000000..35789941a36 --- /dev/null +++ b/source/module_base/test/matrix_test.cpp @@ -0,0 +1,334 @@ +#include"../matrix.h" +#include"gtest/gtest.h" + +/************************************************ +* unit test of class matrix and related functions +***********************************************/ + +/** + * - Tested functions of class matrix: + * - constructor: + * - constructed by nrow and ncloumn + * - constructed by a matrix + * - constructed by the rvalue of a matrix + * - function create + * - operator "=": assigned by a matrix or the rvalue of a matrix + * - operator "()": access the element + * - operator "*=", "+=", "-=" + * - function trace_on + * - function zero_out + * - function max/min/absmax + * - function norm + * - function print (not called in abacus, no need to test) + * + * - Tested functions related to class matrix + * - operator "+", "-", "*" between two matrixs + * - operator "*" between a double and a matrix, and reverse. + * - function transpose + * - function trace_on + * - function mdot + * + */ + +//a mock function of WARNING_QUIT, to avoid the uncorrected call by matrix.cpp at line 37. +namespace ModuleBase +{ + void WARNING_QUIT(const std::string &file,const std::string &description) {return ;} +} + +class matrixTest : public testing::Test +{ + protected: + ModuleBase::matrix m23a,m33a,m33b,m33c,m34a,m34b; + + void SetUp() + { + m23a.create(2,3); + for (int i=1;i<=6;++i) {m23a.c[i-1] = i*1.0;} + + m33a.create(3,3); + for (int i=1;i<=9;++i) {m33a.c[i-1] = i*1.0;} + + m33b.create(3,3); + for (int i=1;i<=9;++i) {m33b.c[i-1] = i*11.1;} + + m33c.create(3,3,true); + m34a.create(3,4,true); + m34b.create(3,4,true); + } + +}; + +TEST(matrix,Constructornrnc) +{ + ModuleBase::matrix m(3,4,true); + EXPECT_EQ(m.nr,3); + EXPECT_EQ(m.nc,4); + EXPECT_DOUBLE_EQ(m(0,0),0.0); + //EXPECT_DEATH(ModuleBase::matrix m(0,1),""); + //EXPECT_DEATH(ModuleBase::matrix m(1,0),""); + //EXPECT_DEATH(ModuleBase::matrix m(-1,1),""); + //EXPECT_DEATH(ModuleBase::matrix m(1,-1),""); +} + +TEST_F(matrixTest,constructorMatrix) +{ + ModuleBase::matrix m(m33a); + int mnr = m.nr; + EXPECT_EQ(mnr,m33a.nr); + EXPECT_EQ(m.nc,m33a.nc); + for (int i=0;i<9;++i) + { + EXPECT_DOUBLE_EQ(m.c[i],m33a.c[i]); + } +} + +TEST_F(matrixTest,constructorMtrixRvalue) +{ + + ModuleBase::matrix m(3.0*m33a); + EXPECT_EQ(m.nr,m33a.nr); + EXPECT_EQ(m.nc,m33a.nc); + for (int i=0;i<9;++i) + { + EXPECT_DOUBLE_EQ(m.c[i],m33a.c[i] * 3.0); + } +} + +TEST_F(matrixTest,create) +{ + m33a.create(13,14,true); + EXPECT_EQ(m33a.nr,13); + EXPECT_EQ(m33a.nc,14); + for(int i=0;i<13*14;++i) + { + EXPECT_DOUBLE_EQ(m33a.c[i],0.0); + } +} + +TEST_F(matrixTest,operatorEqualMatrix) +{ + ModuleBase::matrix m; + m = m33a; + EXPECT_EQ(m.nr,m33a.nr); + EXPECT_EQ(m.nc,m33a.nc); + for (int i=0;i<9;++i) + { + EXPECT_DOUBLE_EQ(m.c[i],m33a.c[i]); + } + + m23a = m33a; + EXPECT_EQ(m23a.nr,m33a.nr); + EXPECT_EQ(m23a.nc,m33a.nc); +} + +TEST_F(matrixTest,operatorEqualMatrixRvalue) +{ + ModuleBase::matrix m; + m = 3.0 * m33a; + EXPECT_EQ(m.nr,m33a.nr); + EXPECT_EQ(m.nc,m33a.nc); + for (int i=0;i<9;++i) + { + EXPECT_DOUBLE_EQ(m.c[i],m33a.c[i] * 3.0); + } +} + +TEST_F(matrixTest,operatorParentheses) +{ + //EXPECT_DEATH(m33a(3,3),""); + //EXPECT_DEATH(m33a(-1,0),""); + m33a(0,0) = 1.1; + EXPECT_DOUBLE_EQ(m33a(0,0),1.1); +} + +TEST_F(matrixTest,operatorMultiplyEqual) +{ + m33b = m33a; + m33a *= 11.1; + for (int i=0;i Date: Fri, 11 Feb 2022 14:48:51 +0800 Subject: [PATCH 14/52] modify the test name following "case upper camel case naming" delete the redundant comment lines. --- source/module_base/test/matrix_test.cpp | 46 +++++++++++-------------- 1 file changed, 21 insertions(+), 25 deletions(-) diff --git a/source/module_base/test/matrix_test.cpp b/source/module_base/test/matrix_test.cpp index 35789941a36..b1365da4565 100644 --- a/source/module_base/test/matrix_test.cpp +++ b/source/module_base/test/matrix_test.cpp @@ -59,19 +59,15 @@ class matrixTest : public testing::Test }; -TEST(matrix,Constructornrnc) +TEST(matrix,ConstructorNrNc) { ModuleBase::matrix m(3,4,true); EXPECT_EQ(m.nr,3); EXPECT_EQ(m.nc,4); - EXPECT_DOUBLE_EQ(m(0,0),0.0); - //EXPECT_DEATH(ModuleBase::matrix m(0,1),""); - //EXPECT_DEATH(ModuleBase::matrix m(1,0),""); - //EXPECT_DEATH(ModuleBase::matrix m(-1,1),""); - //EXPECT_DEATH(ModuleBase::matrix m(1,-1),""); + EXPECT_DOUBLE_EQ(m(0,0),0.0); } -TEST_F(matrixTest,constructorMatrix) +TEST_F(matrixTest,ConstructorMatrix) { ModuleBase::matrix m(m33a); int mnr = m.nr; @@ -83,7 +79,7 @@ TEST_F(matrixTest,constructorMatrix) } } -TEST_F(matrixTest,constructorMtrixRvalue) +TEST_F(matrixTest,ConstructorMtrixRValue) { ModuleBase::matrix m(3.0*m33a); @@ -95,7 +91,7 @@ TEST_F(matrixTest,constructorMtrixRvalue) } } -TEST_F(matrixTest,create) +TEST_F(matrixTest,Create) { m33a.create(13,14,true); EXPECT_EQ(m33a.nr,13); @@ -106,7 +102,7 @@ TEST_F(matrixTest,create) } } -TEST_F(matrixTest,operatorEqualMatrix) +TEST_F(matrixTest,OperatorEqualMatrix) { ModuleBase::matrix m; m = m33a; @@ -122,7 +118,7 @@ TEST_F(matrixTest,operatorEqualMatrix) EXPECT_EQ(m23a.nc,m33a.nc); } -TEST_F(matrixTest,operatorEqualMatrixRvalue) +TEST_F(matrixTest,OperatorEqualMatrixRvalue) { ModuleBase::matrix m; m = 3.0 * m33a; @@ -134,7 +130,7 @@ TEST_F(matrixTest,operatorEqualMatrixRvalue) } } -TEST_F(matrixTest,operatorParentheses) +TEST_F(matrixTest,OperatorParentheses) { //EXPECT_DEATH(m33a(3,3),""); //EXPECT_DEATH(m33a(-1,0),""); @@ -142,7 +138,7 @@ TEST_F(matrixTest,operatorParentheses) EXPECT_DOUBLE_EQ(m33a(0,0),1.1); } -TEST_F(matrixTest,operatorMultiplyEqual) +TEST_F(matrixTest,OperatorMultiplyEqual) { m33b = m33a; m33a *= 11.1; @@ -155,7 +151,7 @@ TEST_F(matrixTest,operatorMultiplyEqual) } } -TEST_F(matrixTest,operatorPlusEqual) +TEST_F(matrixTest,OperatorPlusEqual) { EXPECT_DEATH(m33a += m34a,""); @@ -170,7 +166,7 @@ TEST_F(matrixTest,operatorPlusEqual) } } -TEST_F(matrixTest,operatorMinusEqual) +TEST_F(matrixTest,OperatorMinusEqual) { EXPECT_DEATH(m33a -= m34a,""); @@ -185,7 +181,7 @@ TEST_F(matrixTest,operatorMinusEqual) } } -TEST_F(matrixTest,classMatrixTraceOn) +TEST_F(matrixTest,ClassMatrixTraceOn) { m33a(0,0) = 1.1; m33a(1,1) = 2.2; @@ -193,7 +189,7 @@ TEST_F(matrixTest,classMatrixTraceOn) EXPECT_DOUBLE_EQ(m33a.trace_on(),8.8); } -TEST_F(matrixTest,classMatrixZeroOut) +TEST_F(matrixTest,ClassMatrixZeroOut) { m33a.zero_out(); for (int i=0;i Date: Fri, 11 Feb 2022 15:34:01 +0800 Subject: [PATCH 15/52] test: add UT and annotations of realArray --- source/module_base/realarray.cpp | 22 ++- source/module_base/realarray.h | 194 ++++++++++++++------- source/module_base/test/realarray_test.cpp | 184 +++++++++++++++++++ 3 files changed, 338 insertions(+), 62 deletions(-) create mode 100644 source/module_base/test/realarray_test.cpp diff --git a/source/module_base/realarray.cpp b/source/module_base/realarray.cpp index e6a0e60ddbe..757bb9ee386 100644 --- a/source/module_base/realarray.cpp +++ b/source/module_base/realarray.cpp @@ -53,6 +53,22 @@ realArray::realArray(const int d1,const int d2,const int d3,const int d4) ++arrayCount; } +realArray::realArray(const realArray &cd) +{ + this->size = cd.getSize(); + this->ptr = new double[size]; + for (int i = 0; i < size; i++) + this->ptr[i] = cd.ptr[i]; + this->dim = cd.dim; + this->bound1 = cd.bound1; + this->bound2 = cd.bound2; + this->bound3 = cd.bound3; + this->bound4 = cd.bound4; + + ++arrayCount; +} + + //******************************** // // Destructor for class realArray @@ -65,8 +81,8 @@ realArray ::~realArray() void realArray::freemem() { - delete [] ptr; - ptr = NULL; + delete [] ptr; + ptr = NULL; } void realArray::create(const int d1,const int d2,const int d3,const int d4) @@ -173,4 +189,4 @@ void realArray::zero_out(void) return; } -} \ No newline at end of file +} diff --git a/source/module_base/realarray.h b/source/module_base/realarray.h index 45cc5fc84c4..d55feb86472 100644 --- a/source/module_base/realarray.h +++ b/source/module_base/realarray.h @@ -5,82 +5,158 @@ #ifndef REALARRAY_H #define REALARRAY_H -#include +#include #include #include -#include +#include #ifdef _MCD_CHECK //#include "./src_parallel/mcd.h" #endif - namespace ModuleBase { - +/** + * @brief double float array + * + */ class realArray { -public: - double *ptr; - - realArray(const int d1 = 1 ,const int d2 = 1,const int d3 = 1); - realArray(const int d1, const int d2,const int d3,const int d4); - ~realArray(); - - void create(const int d1,const int d2,const int d3); - void create(const int d1,const int d2,const int d3,const int d4); - - const realArray &operator=(const realArray &right); - const realArray &operator=(const double &right); - - double &operator()(const int d1,const int d2,const int d3); - double &operator()(const int d1,const int d2,const int d3,const int d4); - - const double &operator()(const int d1,const int d2,const int d3)const; - const double &operator()(const int d1,const int d2,const int d3,const int d4)const; - - void zero_out(void); - - int getSize() const - { return size;} - - int getDim() const - { return dim;} - - int getBound1() const - { return bound1;} - - int getBound2() const - { return bound2;} - - int getBound3() const - { return bound3;} - - int getBound4() const - { return bound4;} - - static int getArrayCount(void) - { return arrayCount;} - -private: - int size; - int dim; - int bound1, bound2, bound3, bound4; - static int arrayCount; - - void freemem(); + public: + double *ptr; + + realArray(const int d1 = 1, const int d2 = 1, const int d3 = 1); + realArray(const int d1, const int d2, const int d3, const int d4); + ~realArray(); + + /** + * @brief create 3 dimensional real array + * + * @param[in] d1 The first dimension size + * @param[in] d2 The second dimension size + * @param[in] d3 The third dimension size + */ + void create(const int d1, const int d2, const int d3); + void create(const int d1, const int d2, const int d3, const int d4); + + realArray(const realArray &cd); + + /** + * @brief Equal a realArray to another one + * + * @param right + * @return const realArray& + */ + const realArray &operator=(const realArray &right); + /** + * @brief Set all value of an array to a double float number + * + * @param right + * @return const realArray& + */ + const realArray &operator=(const double &right); + + /** + * @brief Access elements by using operator "()" + * + * @param d1 + * @param d2 + * @param d3 + * @return double& + */ + double &operator()(const int d1, const int d2, const int d3); + double &operator()(const int d1, const int d2, const int d3, const int d4); + + /** + * @brief Access elements by using "()" through pointer + * without changing its elements + * + * @param d1 + * @param d2 + * @param d3 + * @return const double& + */ + const double &operator()(const int d1, const int d2, const int d3) const; + const double &operator()(const int d1, const int d2, const int d3, const int d4) const; + + void zero_out(void); + + /** + * @brief Get the Size object + * + * @return int + */ + int getSize() const + { + return size; + } + + /** + * @brief Get the Dim object + * i.e. the dimension of a real array + * + * @return int + */ + int getDim() const + { + return dim; + } + + /** + * @brief Get the Bound1 object + * i.e. the first dimension size + * + * @return int + */ + int getBound1() const + { + return bound1; + } + + int getBound2() const + { + return bound2; + } + + int getBound3() const + { + return bound3; + } + + int getBound4() const + { + return bound4; + } + + /** + * @brief Get the Array Count object + * + * @return int + */ + static int getArrayCount(void) + { + return arrayCount; + } + + private: + int size; + int dim; + int bound1, bound2, bound3, bound4; + static int arrayCount; + + void freemem(); }; //************************************************** // set elements of a as zeros which a is 1_d array. //************************************************** -template -void zeros(T *u,const int n) +template void zeros(T *u, const int n) { - assert(n>0); - for (int i = 0;i < n;i++) u[i] = 0; + assert(n > 0); + for (int i = 0; i < n; i++) + u[i] = 0; } -} +} // namespace ModuleBase -#endif // realArray class +#endif // realArray class diff --git a/source/module_base/test/realarray_test.cpp b/source/module_base/test/realarray_test.cpp new file mode 100644 index 00000000000..a6c861c2fb4 --- /dev/null +++ b/source/module_base/test/realarray_test.cpp @@ -0,0 +1,184 @@ +#include "../realarray.h" +#include "gtest/gtest.h" + +/************************************************ + * unit test of class realArray + ***********************************************/ + +/** + * - Tested Functions: + * - GetArrayCount + * - get the total number of real array created + * - Construct + * - construct real arrays (3 or 4 dimensions) + * - Creat + * - create a real array (3 or 4 dimensions) + * - GetSize + * - get the total size of a real array + * - GetDim + * - get the dimension of a real array + * - ZeroOut + * - set all elements of a real array to zero + * - GetBound + * - get the size of each dimension of a real array + * - ArrayEqReal + * - set all value of an array to a double float + * - ArrayEqArray + * - equal a realarray to another one + * - Parentheses + * - access element by using operator"()" + * - ConstParentheses + * - access element by using "()" through pointer + * - without changing its elements + */ + +class realArrayTest : public testing::Test +{ +protected: + ModuleBase::realArray a3, a4, b3, b4; + double aa = 11.0; + double bb = 1.0; + int count0; + int count1; + const double zero = 0.0; +}; + +TEST_F(realArrayTest,GetArrayCount) +{ + count0 = ModuleBase::realArray::getArrayCount(); + ModuleBase::realArray c3, c4; + count1 = ModuleBase::realArray::getArrayCount(); + EXPECT_EQ((count1-count0),2); +} + +TEST_F(realArrayTest,Construct) +{ + ModuleBase::realArray x3(1,5,3); + ModuleBase::realArray xp3(x3); + ModuleBase::realArray x4(1,7,3,4); + ModuleBase::realArray xp4(x4); + EXPECT_EQ(x3.getSize(),15); + EXPECT_EQ(xp3.getSize(),15); + EXPECT_EQ(x4.getSize(),84); + EXPECT_EQ(xp4.getSize(),84); +} + +TEST_F(realArrayTest,Create) +{ + a3.create(1,2,3); + a4.create(1,2,3,4); + EXPECT_EQ(a3.getSize(),6); + EXPECT_EQ(a4.getSize(),24); +} + +TEST_F(realArrayTest,GetSize) +{ + ModuleBase::realArray a3(1,5,3); + //std::cout<< &a3 << &(this->a3) < Date: Fri, 11 Feb 2022 15:48:11 +0800 Subject: [PATCH 16/52] modify annotations in realarray.h --- source/module_base/realarray.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/source/module_base/realarray.h b/source/module_base/realarray.h index d55feb86472..1dc2d7deeff 100644 --- a/source/module_base/realarray.h +++ b/source/module_base/realarray.h @@ -79,6 +79,10 @@ class realArray const double &operator()(const int d1, const int d2, const int d3) const; const double &operator()(const int d1, const int d2, const int d3, const int d4) const; + /** + * @brief Set all elements of an IntArray to zero + * + */ void zero_out(void); /** From ceba5b321a2da6972876f236f2c1984476b59048 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 11 Feb 2022 15:50:49 +0800 Subject: [PATCH 17/52] test: add UT and annotations for IntArray class --- source/module_base/intarray.h | 194 ++++++++++---- source/module_base/test/intarray_test.cpp | 308 ++++++++++++++++++++++ 2 files changed, 444 insertions(+), 58 deletions(-) create mode 100644 source/module_base/test/intarray_test.cpp diff --git a/source/module_base/intarray.h b/source/module_base/intarray.h index f8db1250753..64cacf014b7 100644 --- a/source/module_base/intarray.h +++ b/source/module_base/intarray.h @@ -5,10 +5,10 @@ #ifndef INTARRAY_H #define INTARRAY_H -#include +#include #include #include -#include +#include #ifdef _MCD_CHECK //#include "./src_parallel/mcd.h" @@ -16,63 +16,141 @@ namespace ModuleBase { - +/** + * @brief Integer array + * + */ class IntArray { -public: - int *ptr; - - // Constructors for different dimesnions - IntArray(const int d1 = 1, const int d2 = 1); - IntArray(const int d1, const int d2,const int d3); - IntArray(const int d1, const int d2,const int d3,const int d4); - IntArray(const int d1, const int d2,const int d3,const int d4,const int d5); - IntArray(const int d1, const int d2,const int d3,const int d4,const int d5,const int d6); - - ~IntArray(); - - void create(const int d1, const int d2); - void create(const int d1, const int d2, const int d3); - void create(const int d1, const int d2, const int d3, const int d4); - void create(const int d1, const int d2, const int d3, const int d4, const int d5); - void create(const int d1, const int d2, const int d3, const int d4, const int d5, const int d6); - - const IntArray &operator=(const IntArray &right); - const IntArray &operator=(const int &right); - - int &operator()(const int d1, const int d2); - int &operator()(const int d1, const int d2, const int d3); - int &operator()(const int d1, const int d2, const int d3,const int d4); - int &operator()(const int d1, const int d2, const int d3, const int d4, const int d5); - int &operator()(const int d1, const int d2, const int d3, const int d4, const int d5, const int d6); - - const int &operator()(const int d1,const int d2)const; - const int &operator()(const int d1,const int d2,const int d3)const; - const int &operator()(const int d1,const int d2,const int d3,const int d4)const; - const int &operator()(const int d1,const int d2,const int d3,const int d4, const int d5)const; - const int &operator()(const int d1,const int d2,const int d3,const int d4, const int d5, const int d6)const; - - void zero_out(void); - - int getSize() const{ return size;} - int getDim() const{ return dim;} - int getBound1() const{ return bound1;} - int getBound2() const{ return bound2;} - int getBound3() const{ return bound3;} - int getBound4() const { return bound4;} - int getBound5() const { return bound5;} - int getBound6() const { return bound6;} - - static int getArrayCount(void) - { return arrayCount;} - -private: - int size; - int dim; - int bound1, bound2, bound3, bound4, bound5, bound6; - static int arrayCount; - void freemem(); + public: + int *ptr; + + /** + * @brief Construct a new Int Array object + * + * @param d1 The first dimension size + * @param d2 The second dimension size + */ + IntArray(const int d1 = 1, const int d2 = 1); + IntArray(const int d1, const int d2, const int d3); + IntArray(const int d1, const int d2, const int d3, const int d4); + IntArray(const int d1, const int d2, const int d3, const int d4, const int d5); + IntArray(const int d1, const int d2, const int d3, const int d4, const int d5, const int d6); + + ~IntArray(); + + /** + * @brief Create integer arrays + * + * @param[in] d1 + * @param[in] d2 + */ + void create(const int d1, const int d2); + void create(const int d1, const int d2, const int d3); + void create(const int d1, const int d2, const int d3, const int d4); + void create(const int d1, const int d2, const int d3, const int d4, const int d5); + void create(const int d1, const int d2, const int d3, const int d4, const int d5, const int d6); + + /** + * @brief Equal an IntArray to another one + * + * @param right + * @return const IntArray& + */ + const IntArray &operator=(const IntArray &right); + + /** + * @brief Equal all elements of an IntArray to an + * integer + * + * @param right + * @return const IntArray& + */ + const IntArray &operator=(const int &right); + + /** + * @brief Access elements by using operator "()" + * + * @param d1 + * @param d2 + * @return int& + */ + int &operator()(const int d1, const int d2); + int &operator()(const int d1, const int d2, const int d3); + int &operator()(const int d1, const int d2, const int d3, const int d4); + int &operator()(const int d1, const int d2, const int d3, const int d4, const int d5); + int &operator()(const int d1, const int d2, const int d3, const int d4, const int d5, const int d6); + + /** + * @brief Access elements by using "()" through pointer + * without changing its elements + * + * @param d1 + * @param d2 + * @return const int& + */ + const int &operator()(const int d1, const int d2) const; + const int &operator()(const int d1, const int d2, const int d3) const; + const int &operator()(const int d1, const int d2, const int d3, const int d4) const; + const int &operator()(const int d1, const int d2, const int d3, const int d4, const int d5) const; + const int &operator()(const int d1, const int d2, const int d3, const int d4, const int d5, const int d6) const; + + /** + * @brief Set all elements of an IntArray to zero + * + */ + void zero_out(void); + + int getSize() const + { + return size; + } + int getDim() const + { + return dim; + } + int getBound1() const + { + return bound1; + } + int getBound2() const + { + return bound2; + } + int getBound3() const + { + return bound3; + } + int getBound4() const + { + return bound4; + } + int getBound5() const + { + return bound5; + } + int getBound6() const + { + return bound6; + } + + /** + * @brief Get the Array Count object + * + * @return int + */ + static int getArrayCount(void) + { + return arrayCount; + } + + private: + int size; + int dim; + int bound1, bound2, bound3, bound4, bound5, bound6; + static int arrayCount; + void freemem(); }; -} +} // namespace ModuleBase -#endif // IntArray class +#endif // IntArray class diff --git a/source/module_base/test/intarray_test.cpp b/source/module_base/test/intarray_test.cpp new file mode 100644 index 00000000000..d6694abea38 --- /dev/null +++ b/source/module_base/test/intarray_test.cpp @@ -0,0 +1,308 @@ +#include "../intarray.h" +#include "gtest/gtest.h" + +/************************************************ + * unit test of class IntArray + ***********************************************/ + +/** + * - Tested Functions: + * - Construct + * - construct an int array (2 to 6 dimensions) + * - Creat + * - create an int array (2 to 6 dimensions) + * - GetArrayCount + * - get the total number of int array created + * - GetSize + * - get the total size of an int array + * - GetDim + * - get the dimension of an int array + * - ZeroOut + * - set all elements of an int array to zero + * - GetBound + * - get the size of each dimension of an int array + * - ArrayEqReal + * - set all value of an array to an int number + * - ArrayEqArray + * - equal an intarray to another intarray + * - Parentheses + * - access element by using operator"()" + * - ConstParentheses + * - access element by using "()" through pointer + * - without changing its elements + */ + +class IntArrayTest : public testing::Test +{ +protected: + ModuleBase::IntArray a2, a3, a4, a5, a6; + int aa = 11; + int bb = 1; + int count0; + int count1; + const int zero = 0; +}; + +TEST_F(IntArrayTest,GetArrayCount) +{ + count0 = ModuleBase::IntArray::getArrayCount(); + ModuleBase::IntArray c3, c4; + count1 = ModuleBase::IntArray::getArrayCount(); + EXPECT_EQ((count1-count0),2); +} + +TEST_F(IntArrayTest,Construct) +{ + ModuleBase::IntArray x2(1,5); + ModuleBase::IntArray x3(1,5,3); + ModuleBase::IntArray x4(1,7,3,4); + ModuleBase::IntArray x5(1,5,3,8,2); + ModuleBase::IntArray x6(1,7,3,4,3,2); + EXPECT_EQ(x2.getSize(),5); + EXPECT_EQ(x3.getSize(),15); + EXPECT_EQ(x4.getSize(),84); + EXPECT_EQ(x5.getSize(),240); + EXPECT_EQ(x6.getSize(),504); +} + +TEST_F(IntArrayTest,Create) +{ + a2.create(2,1); + a3.create(3,2,1); + a4.create(4,3,2,1); + a5.create(5,4,3,2,1); + a6.create(6,5,4,3,2,1); + EXPECT_EQ(a2.getSize(),2); + EXPECT_EQ(a3.getSize(),6); + EXPECT_EQ(a4.getSize(),24); + EXPECT_EQ(a5.getSize(),120); + EXPECT_EQ(a6.getSize(),720); +} + + +TEST_F(IntArrayTest,GetSize) +{ + ModuleBase::IntArray x3(1,5,3); + ModuleBase::IntArray x4(1,7,3,4); + EXPECT_EQ(x3.getSize(),15); + EXPECT_EQ(x4.getSize(),84); +} + +TEST_F(IntArrayTest,GetDim) +{ + a2.create(2,3); + a3.create(3,5,1); + a4.create(4,3,7,1); + a5.create(5,4,1,2,1); + a6.create(6,5,9,3,2,1); + EXPECT_EQ(a2.getDim(),2); + EXPECT_EQ(a3.getDim(),3); + EXPECT_EQ(a4.getDim(),4); + EXPECT_EQ(a5.getDim(),5); + EXPECT_EQ(a6.getDim(),6); +} + +TEST_F(IntArrayTest,ZeroOut) +{ + a2.create(2,3); + a3.create(3,5,1); + a4.create(4,3,7,1); + a5.create(5,4,1,2,1); + a6.create(6,5,9,3,2,1); + a2.zero_out(); + a3.zero_out(); + a4.zero_out(); + a5.zero_out(); + a6.zero_out(); + for (int i=0;i Date: Fri, 11 Feb 2022 15:52:47 +0800 Subject: [PATCH 18/52] fix a bug caused by memory access --- source/module_orbital/test/ORB_unittest.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/module_orbital/test/ORB_unittest.cpp b/source/module_orbital/test/ORB_unittest.cpp index 86d506ca5f6..531d921003a 100644 --- a/source/module_orbital/test/ORB_unittest.cpp +++ b/source/module_orbital/test/ORB_unittest.cpp @@ -87,7 +87,7 @@ void test_orb::set_orbs() int* nproj = new int[ORB.get_ntype()]; for (int i = 0;i < ORB.get_ntype();++i) nproj[i] = 0; - const Numerical_Nonlocal beta_[1]; + const Numerical_Nonlocal beta_[ORB.get_ntype()]; ooo.set_orb_tables( ofs_running, From 68fdb1b11753c7b3cf6317a1987cadcff02fd66b Mon Sep 17 00:00:00 2001 From: root Date: Fri, 11 Feb 2022 16:34:13 +0800 Subject: [PATCH 19/52] test: add UT and annotations for class Vector3 --- source/module_base/test/vector3_test.cpp | 722 +++++++++++++++++++++++ source/module_base/vector3.h | 405 ++++++++++--- 2 files changed, 1052 insertions(+), 75 deletions(-) create mode 100644 source/module_base/test/vector3_test.cpp diff --git a/source/module_base/test/vector3_test.cpp b/source/module_base/test/vector3_test.cpp new file mode 100644 index 00000000000..0f2f553c619 --- /dev/null +++ b/source/module_base/test/vector3_test.cpp @@ -0,0 +1,722 @@ +#include "../vector3.h" +#include "gtest/gtest.h" +#include "gmock/gmock.h" + +/************************************************ + * unit test of class Vector3 + ***********************************************/ + +/** + * - Tested Functions: + * - Construct + * - two ways of constructing a 3d vector + * - Set + * - set a 3d vector + * - Equal + * - overload operator "=" for 3d vector + * - PlusEqual + * - overload operator "+=" for 3d vector + * - MinusEqual + * - overload operator "-=" for 3d vector + * - MultiplyEqual + * - overload operator "*=" for (3d vector) * scalar + * - OverEqual + * - overload operator "/=" for (3d vector) / scalar + * - Negative + * - overload operator "-" to get - Vector3 + * - Reverse + * - same as negative + * - Access + * - access elements by using "[]" + * - ConstAccess + * - access elements by using "[]" through pinters + * - withough chaning element values + * - VectorPlus + * - overload operator "+" for two 3d vectors + * - VectorMinus + * - overload operator "-" for two 3d vectors + * - Norm2 + * - get the square of norm of a 3d vector + * - Norm + * - get the norm of a 3d vector + * - Normalize + * - normalize a 3d vector + * - VmultiplyV + * - overload operator "*" to calculate + * - the dot product of two 3d vectors + * - VdotV + * - dot product of two 3d vectors + * - VmultiplyNum + * - overload operator "*" to calculate + * - the product of a 3d vector with a scalar + * - of the product of a scalar with a 3d vector + * - VoverNum + * - overload operator "/" to calculate + * - a 3d vector over a scalar + * - OperatorCaret + * - overload operator "^" to calculate + * - the cross product of two 3d vectors + * - VeqV + * - reload operator "==" to assert + * - the equality between two 3d vectors + * - VneV + * - reload operator "!=" to assert + * - the non-equality between two 3d vectors + * - StdOutV + * - reload operator "<<" to print out + * - a 3d vectors on standard output + * - PrintV + * - print a 3d vectors on standard output + * - with formats + */ + +class Vector3Test : public testing::Test +{ +protected: + double da = 3.0; + double db = 4.0; + double dc = 5.0; + int ia = 3; + int ib = 4; + int ic = 5; + float fa = 3.0; + float fb = 4.0; + float fc = 5.0; + // for capturing stdout + std::string output; +}; + +TEST_F(Vector3Test,Construct) +{ + // double Vector3 + ModuleBase::Vector3 u (da,db,dc); + ModuleBase::Vector3 up (u); + EXPECT_EQ(u.x,3.0); + EXPECT_EQ(u.y,4.0); + EXPECT_EQ(u.z,5.0); + EXPECT_EQ(up.x,3.0); + EXPECT_EQ(up.y,4.0); + EXPECT_EQ(up.z,5.0); + // float Vector3 + ModuleBase::Vector3 v (fa,fb,fc); + ModuleBase::Vector3 vp (v); + EXPECT_EQ(v.x,3.0); + EXPECT_EQ(v.y,4.0); + EXPECT_EQ(v.z,5.0); + EXPECT_EQ(vp.x,3.0); + EXPECT_EQ(vp.y,4.0); + EXPECT_EQ(vp.z,5.0); + // int Vector3 + ModuleBase::Vector3 w (ia,ib,ic); + ModuleBase::Vector3 wp (w); + EXPECT_EQ(w.x,3); + EXPECT_EQ(w.y,4); + EXPECT_EQ(w.z,5); + EXPECT_EQ(wp.x,3); + EXPECT_EQ(wp.y,4); + EXPECT_EQ(wp.z,5); +} + +TEST_F(Vector3Test,Set) +{ + // double Vector3 + ModuleBase::Vector3 u; + u.set(da,db,dc); + EXPECT_EQ(u.x,3.0); + EXPECT_EQ(u.y,4.0); + EXPECT_EQ(u.z,5.0); + // float Vector3 + ModuleBase::Vector3 v; + v.set(fa,fb,fc); + EXPECT_EQ(v.x,3.0); + EXPECT_EQ(v.y,4.0); + EXPECT_EQ(v.z,5.0); + // int Vector3 + ModuleBase::Vector3 w; + w.set(ia,ib,ic); + EXPECT_EQ(w.x,3); + EXPECT_EQ(w.y,4); + EXPECT_EQ(w.z,5); +} + +TEST_F(Vector3Test,Equal) +{ + // double Vector3 + ModuleBase::Vector3 u, up; + u.set(da,db,dc); + up = u; + EXPECT_EQ(up.x,3.0); + EXPECT_EQ(up.y,4.0); + EXPECT_EQ(up.z,5.0); + // float Vector3 + ModuleBase::Vector3 v, vp; + v.set(fa,fb,fc); + vp = v; + EXPECT_EQ(vp.x,3.0); + EXPECT_EQ(vp.y,4.0); + EXPECT_EQ(vp.z,5.0); + // int Vector3 + ModuleBase::Vector3 w, wp; + w.set(ia,ib,ic); + wp = w; + EXPECT_EQ(wp.x,3); + EXPECT_EQ(wp.y,4); + EXPECT_EQ(wp.z,5); +} + +TEST_F(Vector3Test,PlusEqual) +{ + // double Vector3 + ModuleBase::Vector3 u, up; + u.set(da,db,dc); + up.set(da,db,dc); + up += u; + EXPECT_EQ(up.x,6.0); + EXPECT_EQ(up.y,8.0); + EXPECT_EQ(up.z,10.0); + // float Vector3 + ModuleBase::Vector3 v, vp; + v.set(fa,fb,fc); + vp.set(fa,fb,fc); + vp += v; + EXPECT_EQ(vp.x,6.0); + EXPECT_EQ(vp.y,8.0); + EXPECT_EQ(vp.z,10.0); + // int Vector3 + ModuleBase::Vector3 w, wp; + w.set(ia,ib,ic); + wp.set(ia,ib,ic); + wp += w; + EXPECT_EQ(wp.x,6); + EXPECT_EQ(wp.y,8); + EXPECT_EQ(wp.z,10); +} + +TEST_F(Vector3Test,MinusEqual) +{ + // double Vector3 + ModuleBase::Vector3 u, up; + u.set(da,db,dc); + up.set(3*da,3*db,3*dc); + up -= u; + EXPECT_EQ(up.x,6.0); + EXPECT_EQ(up.y,8.0); + EXPECT_EQ(up.z,10.0); + // float Vector3 + ModuleBase::Vector3 v, vp; + v.set(fa,fb,fc); + vp.set(3*fa,3*fb,3*fc); + vp -= v; + EXPECT_EQ(vp.x,6.0); + EXPECT_EQ(vp.y,8.0); + EXPECT_EQ(vp.z,10.0); + // int Vector3 + ModuleBase::Vector3 w, wp; + w.set(ia,ib,ic); + wp.set(3*ia,3*ib,3*ic); + wp -= w; + EXPECT_EQ(wp.x,6); + EXPECT_EQ(wp.y,8); + EXPECT_EQ(wp.z,10); +} + +TEST_F(Vector3Test,MultiplyEqual) +{ + // double Vector3 + ModuleBase::Vector3 u; + u.set(da,db,dc); + u *= 2; + EXPECT_EQ(u.x,6.0); + EXPECT_EQ(u.y,8.0); + EXPECT_EQ(u.z,10.0); + // float Vector3 + ModuleBase::Vector3 v; + v.set(fa,fb,fc); + v *= 2; + EXPECT_EQ(v.x,6.0); + EXPECT_EQ(v.y,8.0); + EXPECT_EQ(v.z,10.0); + // int Vector3 + ModuleBase::Vector3 w; + w.set(ia,ib,ic); + w *= 2; + EXPECT_EQ(w.x,6); + EXPECT_EQ(w.y,8); + EXPECT_EQ(w.z,10); +} + +TEST_F(Vector3Test,OverEqual) +{ + // double Vector3 + ModuleBase::Vector3 u; + u.set(4*da,4*db,4*dc); + u /= 2; + EXPECT_EQ(u.x,6.0); + EXPECT_EQ(u.y,8.0); + EXPECT_EQ(u.z,10.0); + // float Vector3 + ModuleBase::Vector3 v; + v.set(4*fa,4*fb,4*fc); + v /= 2; + EXPECT_EQ(v.x,6.0); + EXPECT_EQ(v.y,8.0); + EXPECT_EQ(v.z,10.0); + // int Vector3 + ModuleBase::Vector3 w; + w.set(4*ia,4*ib,4*ic); + w /= 2; + EXPECT_EQ(w.x,6); + EXPECT_EQ(w.y,8); + EXPECT_EQ(w.z,10); +} + +TEST_F(Vector3Test,Negative) +{ + // double Vector3 + ModuleBase::Vector3 u, up; + u.set(da,db,dc); + up = -u; + EXPECT_EQ(up.x,-3.0); + EXPECT_EQ(up.y,-4.0); + EXPECT_EQ(up.z,-5.0); + // float Vector3 + ModuleBase::Vector3 v, vp; + v.set(fa,fb,fc); + vp = -v; + EXPECT_EQ(vp.x,-3.0); + EXPECT_EQ(vp.y,-4.0); + EXPECT_EQ(vp.z,-5.0); + // int Vector3 + ModuleBase::Vector3 w, wp; + w.set(ia,ib,ic); + wp = -w; + EXPECT_EQ(wp.x,-3); + EXPECT_EQ(wp.y,-4); + EXPECT_EQ(wp.z,-5); +} + +TEST_F(Vector3Test,Access) +{ + // double Vector3 + ModuleBase::Vector3 u; + u.set(da,db,dc); + EXPECT_EQ(u[0],3.0); + EXPECT_EQ(u[1],4.0); + EXPECT_EQ(u[2],5.0); + // float Vector3 + ModuleBase::Vector3 v; + v.set(fa,fb,fc); + EXPECT_EQ(v.x,3.0); + EXPECT_EQ(v.y,4.0); + EXPECT_EQ(v.z,5.0); + // int Vector3 + ModuleBase::Vector3 w; + w.set(ia,ib,ic); + EXPECT_EQ(w.x,3); + EXPECT_EQ(w.y,4); + EXPECT_EQ(w.z,5); +} + + +TEST_F(Vector3Test,ConstAccess) +{ + // double Vector3 + ModuleBase::Vector3 u; + u.set(da,db,dc); + const ModuleBase::Vector3 *up(&u); + EXPECT_EQ((*up)[0],3.0); + EXPECT_EQ((*up)[1],4.0); + EXPECT_EQ((*up)[2],5.0); + // float Vector3 + ModuleBase::Vector3 v; + const ModuleBase::Vector3 *vp(&v); + v.set(fa,fb,fc); + EXPECT_EQ((*vp).x,3.0); + EXPECT_EQ((*vp).y,4.0); + EXPECT_EQ((*vp).z,5.0); + // int Vector3 + //ModuleBase::Vector3 w; + //w.set(ia,ib,ic); + //EXPECT_EQ(w.x,3); + //EXPECT_EQ(w.y,4); + //EXPECT_EQ(w.z,5); +} + + +TEST_F(Vector3Test,Reverse) +{ + // double Vector3 + ModuleBase::Vector3 u; + u.set(da,db,dc); + u.reverse(); + EXPECT_EQ(u.x,-3.0); + EXPECT_EQ(u.y,-4.0); + EXPECT_EQ(u.z,-5.0); + // float Vector3 + ModuleBase::Vector3 v; + v.set(fa,fb,fc); + v.reverse(); + EXPECT_EQ(v.x,-3.0); + EXPECT_EQ(v.y,-4.0); + EXPECT_EQ(v.z,-5.0); + // int Vector3 + ModuleBase::Vector3 w; + w.set(ia,ib,ic); + w.reverse(); + EXPECT_EQ(w.x,-3); + EXPECT_EQ(w.y,-4); + EXPECT_EQ(w.z,-5); +} + +TEST_F(Vector3Test,VectorPlus) +{ + // double Vector3 + ModuleBase::Vector3 u,up,upp; + u.set(da,db,dc); + up.set(da,db,dc); + upp = u + up; + EXPECT_EQ(upp[0],6.0); + EXPECT_EQ(upp[1],8.0); + EXPECT_EQ(upp[2],10.0); + // float Vector3 + ModuleBase::Vector3 v,vp,vpp; + v.set(fa,fb,fc); + vp.set(fa,fb,fc); + vpp = v + vp; + EXPECT_EQ(vpp.x,6.0); + EXPECT_EQ(vpp.y,8.0); + EXPECT_EQ(vpp.z,10.0); + // int Vector3 + ModuleBase::Vector3 w,wp,wpp; + w.set(ia,ib,ic); + wp.set(ia,ib,ic); + wpp = w + wp; + EXPECT_EQ(wpp.x,6); + EXPECT_EQ(wpp.y,8); + EXPECT_EQ(wpp.z,10); +} + +TEST_F(Vector3Test,VectorMinus) +{ + // double Vector3 + ModuleBase::Vector3 u,up,upp; + u.set(da,db,dc); + up.set(2*da,2*db,2*dc); + upp = u - up; + EXPECT_EQ(upp[0],-3.0); + EXPECT_EQ(upp[1],-4.0); + EXPECT_EQ(upp[2],-5.0); + // float Vector3 + ModuleBase::Vector3 v,vp,vpp; + v.set(fa,fb,fc); + vp.set(3*fa,3*fb,3*fc); + vpp = v - vp; + EXPECT_EQ(vpp.x,-6.0); + EXPECT_EQ(vpp.y,-8.0); + EXPECT_EQ(vpp.z,-10.0); + // int Vector3 + ModuleBase::Vector3 w,wp,wpp; + w.set(3*ia,3*ib,3*ic); + wp.set(ia,ib,ic); + wpp = w - wp; + EXPECT_EQ(wpp.x,6); + EXPECT_EQ(wpp.y,8); + EXPECT_EQ(wpp.z,10); +} + +TEST_F(Vector3Test,Norm2) +{ + // double Vector3 + ModuleBase::Vector3 u; + u.set(da,db,dc); + EXPECT_EQ(u.norm2(),50.0); + // float Vector3 + ModuleBase::Vector3 v; + v.set(fa,fb,fc); + EXPECT_EQ(v.norm2(),50.0); + // int Vector3 + ModuleBase::Vector3 w; + w.set(ia,ib,ic); + EXPECT_EQ(w.norm2(),50); +} + + +TEST_F(Vector3Test,Norm) +{ + // double Vector3 + ModuleBase::Vector3 u; + u.set(da,db,dc); + double nm = u.norm(); + double nm2= sqrt(50.0); + EXPECT_DOUBLE_EQ(nm,nm2); + EXPECT_FLOAT_EQ(nm,sqrt(50.0)); + // float Vector3 + ModuleBase::Vector3 v; + v.set(fa,fb,fc); + float nmp = v.norm(); + float nmp2= sqrt(50.0); + EXPECT_FLOAT_EQ(nmp,sqrt(50.0)); +} + + +TEST_F(Vector3Test,Normalize) +{ + // double Vector3 + ModuleBase::Vector3 u; + u.set(da,db,dc); + u.normalize(); + EXPECT_DOUBLE_EQ(u.norm(),1.0); + // float Vector3 + ModuleBase::Vector3 v; + v.set(fa,fb,fc); + v.normalize(); + EXPECT_FLOAT_EQ(v.norm(),1.0); +} + +TEST_F(Vector3Test,VmultiplyV) +{ + // double Vector3 + ModuleBase::Vector3 u,up; + u.set(da,db,dc); + up.set(da,db,dc); + double mpd = u * up; + EXPECT_EQ(mpd,50.0); + // float Vector3 + ModuleBase::Vector3 v,vp; + v.set(fa,fb,fc); + vp.set(fa,fb,fc); + float mpf = v*vp; + EXPECT_EQ(mpf,50.0); + // int Vector3 + ModuleBase::Vector3 w,wp; + w.set(ia,ib,ic); + wp.set(ia,ib,ic); + int mpi = w*wp; + EXPECT_EQ(mpf,50); +} + +TEST_F(Vector3Test,VdotV) +{ + // double Vector3 + ModuleBase::Vector3 u,up; + u.set(da,db,dc); + up.set(da,db,dc); + double mpd = dot(u,up); + EXPECT_EQ(mpd,50.0); + // float Vector3 + ModuleBase::Vector3 v,vp; + v.set(fa,fb,fc); + vp.set(fa,fb,fc); + float mpf = dot(v,vp); + EXPECT_EQ(mpf,50.0); + // int Vector3 + ModuleBase::Vector3 w,wp; + w.set(ia,ib,ic); + wp.set(ia,ib,ic); + int mpi = dot(w,wp); + EXPECT_EQ(mpf,50); +} + +TEST_F(Vector3Test,VmultiplyNum) +{ + // double Vector3 + ModuleBase::Vector3 u,up,upp; + u.set(da,db,dc); + double s = 3.0; + up = s*u; upp = u*s; + EXPECT_EQ(upp[0],up[0]); + EXPECT_EQ(upp[1],up[1]); + EXPECT_EQ(upp[2],up[2]); + EXPECT_EQ(upp[0],9.0); + EXPECT_EQ(upp[1],12.0); + EXPECT_EQ(upp[2],15.0); + // float Vector3 + ModuleBase::Vector3 v,vp,vpp; + v.set(fa,fb,fc); + float t = 3.0; + vp = t*v; vpp = v*t; + EXPECT_EQ(vpp[0],vp[0]); + EXPECT_EQ(vpp[1],vp[1]); + EXPECT_EQ(vpp[2],vp[2]); + EXPECT_EQ(vpp[0],9.0); + EXPECT_EQ(vpp[1],12.0); + EXPECT_EQ(vpp[2],15.0); + // int Vector3 + ModuleBase::Vector3 w,wp,wpp; + w.set(ia,ib,ic); + int q = 3; + wp = q*w; wpp = w*q; + EXPECT_EQ(wpp[0],wp[0]); + EXPECT_EQ(wpp[1],wp[1]); + EXPECT_EQ(wpp[2],wp[2]); + EXPECT_EQ(wpp[0],9.0); + EXPECT_EQ(wpp[1],12.0); + EXPECT_EQ(wpp[2],15.0); +} + +TEST_F(Vector3Test,VoverNum) +{ + // double Vector3 + ModuleBase::Vector3 u,up; + u.set(2*da,2*db,2*dc); + double s = 2.0; + up = u/s; + EXPECT_EQ(up.x,3.0); + EXPECT_EQ(up.y,4.0); + EXPECT_EQ(up.z,5.0); + // float Vector3 + ModuleBase::Vector3 v,vp; + v.set(2*fa,2*fb,2*fc); + float t = 2.0; + vp = v/t; + EXPECT_EQ(vp.x,3.0); + EXPECT_EQ(vp.y,4.0); + EXPECT_EQ(vp.z,5.0); + // int Vector3 + ModuleBase::Vector3 w,wp; + w.set(2*ia,2*ib,2*ic); + int q = 2; + wp = w/q; + EXPECT_EQ(wp.x,3); + EXPECT_EQ(wp.y,4); + EXPECT_EQ(wp.z,5); +} + +TEST_F(Vector3Test,OperatorCaret) +{ + // double Vector3 + ModuleBase::Vector3 u,up,upp; + u.set(da,db,dc); + up.set(da,db,dc); + upp = u^up; + EXPECT_EQ(upp.x,u.y*up.z - u.z*up.y); + EXPECT_EQ(upp.y,u.z*up.x - u.x*up.z); + EXPECT_EQ(upp.z,u.x*up.y - u.y*up.x); + // float Vector3 + ModuleBase::Vector3 v,vp,vpp; + v.set(2*fa,2*fb,2*fc); + vp.set(fa,fb,fc); + vpp = v^vp; + EXPECT_EQ(vpp.x,v.y*vp.z - v.z*vp.y); + EXPECT_EQ(vpp.y,v.z*vp.x - v.x*vp.z); + EXPECT_EQ(vpp.z,v.x*vp.y - v.y*vp.x); + // int Vector3 + ModuleBase::Vector3 w,wp,wpp; + w.set(2*ia,2*ib,2*ic); + wp.set(ia,ib,ic); + wpp = w^wp; + EXPECT_EQ(wpp.x,w.y*wp.z - w.z*wp.y); + EXPECT_EQ(wpp.y,w.z*wp.x - w.x*wp.z); + EXPECT_EQ(wpp.z,w.x*wp.y - w.y*wp.x); +} + +TEST_F(Vector3Test,Cross) +{ + // double Vector3 + ModuleBase::Vector3 u,up,upp; + u.set(da,db,dc); + up.set(da,db,dc); + upp = cross(u,up); + EXPECT_EQ(upp.x,u.y*up.z - u.z*up.y); + EXPECT_EQ(upp.y,u.z*up.x - u.x*up.z); + EXPECT_EQ(upp.z,u.x*up.y - u.y*up.x); + // float Vector3 + ModuleBase::Vector3 v,vp,vpp; + v.set(2*fa,2*fb,2*fc); + vp.set(fa,fb,fc); + vpp = cross(v,vp); + EXPECT_EQ(vpp.x,v.y*vp.z - v.z*vp.y); + EXPECT_EQ(vpp.y,v.z*vp.x - v.x*vp.z); + EXPECT_EQ(vpp.z,v.x*vp.y - v.y*vp.x); + // int Vector3 + ModuleBase::Vector3 w,wp,wpp; + w.set(2*ia,2*ib,2*ic); + wp.set(ia,ib,ic); + wpp = cross(w,wp); + EXPECT_EQ(wpp.x,w.y*wp.z - w.z*wp.y); + EXPECT_EQ(wpp.y,w.z*wp.x - w.x*wp.z); + EXPECT_EQ(wpp.z,w.x*wp.y - w.y*wp.x); +} + +TEST_F(Vector3Test,VeqV) +{ + // double Vector3 + ModuleBase::Vector3 u,up; + u.set(da,db,dc); + up.set(da,db,dc); + EXPECT_TRUE(up == u); + // float Vector3 + ModuleBase::Vector3 v,vp; + v.set(fa,fb,fc); + vp.set(fa,fb,fc); + EXPECT_TRUE(vp == v); + // int Vector3 + ModuleBase::Vector3 w,wp; + w.set(ia,ib,ic); + wp.set(ia,ib,ic); + EXPECT_TRUE(wp == w); +} + +TEST_F(Vector3Test,VneV) +{ + // double Vector3 + ModuleBase::Vector3 u,up; + u.set(da,db,dc); + up.set(da,db,2*dc); + EXPECT_TRUE(up != u); + // float Vector3 + ModuleBase::Vector3 v,vp; + v.set(fa,fb,2*fc); + vp.set(fa,fb,fc); + EXPECT_TRUE(vp != v); + // int Vector3 + ModuleBase::Vector3 w,wp; + w.set(ia,ib,2*ic); + wp.set(ia,ib,ic); + EXPECT_TRUE(wp != w); +} + +TEST_F(Vector3Test,StdOutV) +{ + // double Vector3 + ModuleBase::Vector3 u(da,db,dc); + testing::internal::CaptureStdout(); + std::cout << u << std::endl; + output = testing::internal::GetCapturedStdout(); + EXPECT_THAT(output,testing::HasSubstr("(")); + // float Vector3 + ModuleBase::Vector3 v(fa,fb,fc); + testing::internal::CaptureStdout(); + std::cout << v << std::endl; + output = testing::internal::GetCapturedStdout(); + EXPECT_THAT(output,testing::HasSubstr(",")); + // int Vector3 + ModuleBase::Vector3 w(ia,ib,ic); + testing::internal::CaptureStdout(); + std::cout << w << std::endl; + output = testing::internal::GetCapturedStdout(); + EXPECT_THAT(output,testing::HasSubstr(")")); +} + +TEST_F(Vector3Test,PrintV) +{ + // double Vector3 + ModuleBase::Vector3 u(3.1415926,db,dc); + testing::internal::CaptureStdout(); + u.print(); + output = testing::internal::GetCapturedStdout(); + EXPECT_THAT(output,testing::HasSubstr("3.1416")); + // float Vector3 + ModuleBase::Vector3 v(fa,fb,3.14); + testing::internal::CaptureStdout(); + v.print(); + output = testing::internal::GetCapturedStdout(); + EXPECT_THAT(output,testing::HasSubstr("3.14")); + // int Vector3 + ModuleBase::Vector3 w(ia,101,ic); + testing::internal::CaptureStdout(); + w.print(); + output = testing::internal::GetCapturedStdout(); + EXPECT_THAT(output,testing::HasSubstr("101")); +} + diff --git a/source/module_base/vector3.h b/source/module_base/vector3.h index f230f616f3c..bb6502fb59a 100644 --- a/source/module_base/vector3.h +++ b/source/module_base/vector3.h @@ -6,103 +6,358 @@ #endif #include -#include #include +#include namespace ModuleBase { -template -class Vector3 +/** + * @brief 3 elements vector + * + * @tparam T + */ +template class Vector3 { -public: - T x; - T y; - T z; - - Vector3(const T &x1 = 0,const T &y1 = 0,const T &z1 = 0) :x(x1),y(y1),z(z1){}; - Vector3(const Vector3 &v) :x(v.x),y(v.y),z(v.z){}; // Peize Lin add 2018-07-16 - void set(const T &x1, const T &y1,const T &z1) { x = x1; y = y1; z = z1; } - - Vector3& operator =(const Vector3 &u) { x=u.x; y=u.y; z=u.z; return *this; } - Vector3& operator+=(const Vector3 &u) { x+=u.x; y+=u.y; z+=u.z; return *this; } - Vector3& operator-=(const Vector3 &u) { x-=u.x; y-=u.y; z-=u.z; return *this; } - Vector3& operator*=(const Vector3 &u); - Vector3& operator*=(const T &s) { x*=s; y*=s; z*=s; return *this; } - Vector3& operator/=(const Vector3 &u); - Vector3& operator/=(const T &s) { x/=s; y/=s; z/=s; return *this; } - Vector3 operator -() const { return Vector3(-x,-y,-z); } // Peize Lin add 2017-01-10 - - T operator[](int index)const { return (&x)[index]; } - T& operator[](int index) { return (&x)[index]; } - - T norm2(void) const { return x*x + y*y + z*z; } - T norm(void) const { return sqrt(norm2()); } - Vector3& normalize(void){ const T m=norm(); x/=m; y/=m; z/=m; return *this; } // Peize Lin update return 2019-09-08 - Vector3& reverse(void){ x=-x; y=-y; z=-z; return *this; } // Peize Lin update return 2019-09-08 - - void print(void)const ; // mohan add 2009-11-29 + public: + T x; + T y; + T z; + + /** + * @brief Construct a new Vector 3 object + * + * @param x1 + * @param y1 + * @param z1 + */ + Vector3(const T &x1 = 0, const T &y1 = 0, const T &z1 = 0) : x(x1), y(y1), z(z1){}; + Vector3(const Vector3 &v) : x(v.x), y(v.y), z(v.z){}; // Peize Lin add 2018-07-16 + + /** + * @brief set a 3d vector + * + * @param x1 + * @param y1 + * @param z1 + */ + void set(const T &x1, const T &y1, const T &z1) + { + x = x1; + y = y1; + z = z1; + } + + /** + * @brief Overload operator "=" for Vector3 + * + * @param u + * @return Vector3& + */ + Vector3 &operator=(const Vector3 &u) + { + x = u.x; + y = u.y; + z = u.z; + return *this; + } + + /** + * @brief Overload operator "+=" for Vector3 + * + * @param u + * @return Vector3& + */ + Vector3 &operator+=(const Vector3 &u) + { + x += u.x; + y += u.y; + z += u.z; + return *this; + } + + /** + * @brief Overload operator "-=" for Vector3 + * + * @param u + * @return Vector3& + */ + Vector3 &operator-=(const Vector3 &u) + { + x -= u.x; + y -= u.y; + z -= u.z; + return *this; + } + + /** + * @brief Overload operator "*=" for (Vector3)*scalar + * + * @param s + * @return Vector3& + */ + Vector3 &operator*=(const T &s) + { + x *= s; + y *= s; + z *= s; + return *this; + } + + /** + * @brief Overload operator "/=" for (Vector3)/scalar + * + * @param s + * @return Vector3& + */ + Vector3 &operator/=(const T &s) + { + x /= s; + y /= s; + z /= s; + return *this; + } + + /** + * @brief Overload operator "-" to get (-Vector3) + * + * @return Vector3 + */ + Vector3 operator-() const + { + return Vector3(-x, -y, -z); + } // Peize Lin add 2017-01-10 + + /** + * @brief Over load "[]" for accessing elements with pointers + * + * @param index + * @return T + */ + T operator[](int index) const + { + return (&x)[index]; + } + + /** + * @brief Overload operator "[]" for accesing elements + * + * @param index + * @return T& + */ + T &operator[](int index) + { + return (&x)[index]; + } + + /** + * @brief Get the square of nomr of a Vector3 + * + * @return T + */ + T norm2(void) const + { + return x * x + y * y + z * z; + } + + /** + * @brief Get the norm of a Vector3 + * + * @return T + */ + T norm(void) const + { + return sqrt(norm2()); + } + + /** + * @brief Normalize a Vector3 + * + * @return Vector3& + */ + Vector3 &normalize(void) + { + const T m = norm(); + x /= m; + y /= m; + z /= m; + return *this; + } // Peize Lin update return 2019-09-08 + + /** + * @brief Get (-Vector3) + * + * @return Vector3& + */ + Vector3 &reverse(void) + { + x = -x; + y = -y; + z = -z; + return *this; + } // Peize Lin update return 2019-09-08 + + /** + * @brief Print a Vector3 on standard output + * with formats + * + */ + void print(void) const; // mohan add 2009-11-29 }; -template inline Vector3 operator+( const Vector3 &u, const Vector3 &v ) { return Vector3( u.x+v.x, u.y+v.y, u.z+v.z ); } -template inline Vector3 operator-( const Vector3 &u, const Vector3 &v ) { return Vector3( u.x-v.x, u.y-v.y, u.z-v.z ); } -//u.v=(ux*vx)+(uy*vy)+(uz*vz) -template inline T operator*( const Vector3 &u, const Vector3 &v ) { return ( u.x*v.x + u.y*v.y + u.z*v.z ); } -template inline Vector3 operator*( const T &s, const Vector3 &u ) { return Vector3( u.x*s, u.y*s, u.z*s ); } -template inline Vector3 operator*( const Vector3 &u, const T &s ) { return Vector3( u.x*s, u.y*s, u.z*s ); } // mohan add 2009-5-10 -template inline Vector3 operator/( const Vector3 &u, const T &s ) { return Vector3( u.x/s, u.y/s, u.z/s ); } -//u.v=(ux*vx)+(uy*vy)+(uz*vz) -template inline T dot ( const Vector3 &u, const Vector3 &v ) { return ( u.x*v.x + u.y*v.y + u.z*v.z ); } -// | i j k | -// | ux uy uz | -// | vx vy vz | -// u.v=(uy*vz-uz*vy)i+(-ux*vz+uz*vx)j+(ux*vy-uy*vx)k -template inline Vector3 operator^(const Vector3 &u,const Vector3 &v) -{ - return Vector3 ( u.y * v.z - u.z * v.y, - -u.x * v.z + u.z * v.x, - u.x * v.y - u.y * v.x); +/** + * @brief Overload "+" for two Vector3 + * + * @param[in] u + * @param[in] v + * @return Vector3 + */ +template inline Vector3 operator+(const Vector3 &u, const Vector3 &v) +{ + return Vector3(u.x + v.x, u.y + v.y, u.z + v.z); +} + +/** + * @brief Overload "-" for two Vector3 + * + * @param[in] u + * @param[in] v + * @return Vector3 + */ +template inline Vector3 operator-(const Vector3 &u, const Vector3 &v) +{ + return Vector3(u.x - v.x, u.y - v.y, u.z - v.z); +} + +/** + * @brief Overload "*" to calculate the dot product + * of two Vector3 + * + * @param u + * @param v + * @return template + */ +template inline T operator*(const Vector3 &u, const Vector3 &v) +{ + return (u.x * v.x + u.y * v.y + u.z * v.z); } -// | i j k | -// | ux uy uz | -// | vx vy vz | -// u.v=(uy*vz-uz*vy)i+(-ux*vz+uz*vx)j+(ux*vy-uy*vzx)k -template inline Vector3 cross(const Vector3 &u,const Vector3 &v) + +/** + * @brief Overload "*" to calculate (Vector3)*scalar + * + * @param[in] s + * @param[in] u + * @return Vector3 + */ +template inline Vector3 operator*(const T &s, const Vector3 &u) { - return Vector3 ( u.y * v.z - u.z * v.y, - -u.x * v.z + u.z * v.x, - u.x * v.y - u.y * v.x); + return Vector3(u.x * s, u.y * s, u.z * s); } -//s = u.(v x w) -//template T TripleScalarProduct(Vector3 u, Vector3 v, Vector3 w) + +/** + * @brief Overload "*" to calculate scalar*(Vector3) + * + * @param u + * @param s + * @return Vector3 + */ +template inline Vector3 operator*(const Vector3 &u, const T &s) +{ + return Vector3(u.x * s, u.y * s, u.z * s); +} // mohan add 2009-5-10 + +/** + * @brief Overload "/" to calculate Vector3/scalar + * + * @tparam T + * @param u + * @param s + * @return Vector3 + */ +template inline Vector3 operator/(const Vector3 &u, const T &s) +{ + return Vector3(u.x / s, u.y / s, u.z / s); +} + +/** + * @brief Dot productor of two Vector3 + * + * @param u + * @param v + * @return T + * @note u.v=(ux*vx)+(uy*vy)+(uz*vz) + */ +template inline T dot(const Vector3 &u, const Vector3 &v) +{ + return (u.x * v.x + u.y * v.y + u.z * v.z); +} + +/** + * @brief Overload "^" for cross product of two Vector3 + * + * @param u + * @param v + * @return template + * @note + * | i j k | + * | ux uy uz | + * | vx vy vz | + * u.v=(uy*vz-uz*vy)i+(-ux*vz+uz*vx)j+(ux*vy-uy*vx)k + */ +template inline Vector3 operator^(const Vector3 &u, const Vector3 &v) +{ + return Vector3(u.y * v.z - u.z * v.y, -u.x * v.z + u.z * v.x, u.x * v.y - u.y * v.x); +} + +/** + * @brief Cross product of two Vector3 + * + * @param u + * @param v + * @return template + * @note + * | i j k | + * | ux uy uz | + * | vx vy vz | + * u.v=(uy*vz-uz*vy)i+(-ux*vz+uz*vx)j+(ux*vy-uy*vx)k + */ +template inline Vector3 cross(const Vector3 &u, const Vector3 &v) +{ + return Vector3(u.y * v.z - u.z * v.y, -u.x * v.z + u.z * v.x, u.x * v.y - u.y * v.x); +} +// s = u.(v x w) +// template T TripleScalarProduct(Vector3 u, Vector3 v, Vector3 w) //{ // return T((u.x * (v.y * w.z - v.z * w.y)) + // (u.y * (-v.x * w.z + v.z * w.x)) + // (u.z * (v.x * w.y - v.y * w.x))); -//} +// } -//whether m1 != m2 -template inline bool operator !=(const Vector3 &u, const Vector3 &v){ return !(u == v); } -//whether u == v -template inline bool operator ==(const Vector3 &u, const Vector3 &v) +// whether m1 != m2 +template inline bool operator!=(const Vector3 &u, const Vector3 &v) { - if(u.x == v.x && u.y == v.y && u.z == v.z) - return true; - return false; + return !(u == v); +} +// whether u == v +template inline bool operator==(const Vector3 &u, const Vector3 &v) +{ + if (u.x == v.x && u.y == v.y && u.z == v.z) + return true; + return false; } -template void Vector3::print(void)const +template void Vector3::print(void) const { - std::cout.precision(5) ; - std::cout << "(" << std::setw(10) << x << "," << std::setw(10) << y << "," - << std::setw(10) << z << ")" << std::endl ; - return ; + std::cout.precision(5); + std::cout << "(" << std::setw(10) << x << "," << std::setw(10) << y << "," << std::setw(10) << z << ")" + << std::endl; + return; } -template static std::ostream & operator << ( std::ostream &os, const Vector3 &u ) +template static std::ostream &operator<<(std::ostream &os, const Vector3 &u) { - os << "(" << std::setw(10) << u.x << "," << std::setw(10) << u.y << "," << std::setw(10) << u.z << ")"; - return os; + os << "(" << std::setw(10) << u.x << "," << std::setw(10) << u.y << "," << std::setw(10) << u.z << ")"; + return os; } -} +} // namespace ModuleBase #endif From 89fb837298166005b59137e72f302d2a0fb1f3fc Mon Sep 17 00:00:00 2001 From: root Date: Fri, 11 Feb 2022 16:39:50 +0800 Subject: [PATCH 20/52] modify annotations in vector3 class --- source/module_base/test/vector3_test.cpp | 8 ++++---- source/module_base/vector3.h | 14 ++++++++++++++ 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/source/module_base/test/vector3_test.cpp b/source/module_base/test/vector3_test.cpp index 0f2f553c619..e806802c702 100644 --- a/source/module_base/test/vector3_test.cpp +++ b/source/module_base/test/vector3_test.cpp @@ -57,13 +57,13 @@ * - overload operator "^" to calculate * - the cross product of two 3d vectors * - VeqV - * - reload operator "==" to assert + * - overload operator "==" to assert * - the equality between two 3d vectors * - VneV - * - reload operator "!=" to assert - * - the non-equality between two 3d vectors + * - overload operator "!=" to assert + * - the inequality between two 3d vectors * - StdOutV - * - reload operator "<<" to print out + * - overload operator "<<" to print out * - a 3d vectors on standard output * - PrintV * - print a 3d vectors on standard output diff --git a/source/module_base/vector3.h b/source/module_base/vector3.h index bb6502fb59a..94a09afcfdb 100644 --- a/source/module_base/vector3.h +++ b/source/module_base/vector3.h @@ -345,6 +345,10 @@ template inline bool operator==(const Vector3 &u, const Vector3 return false; } +/** + * @brief Print a Vector3 on standard output with formats + * + */ template void Vector3::print(void) const { std::cout.precision(5); @@ -352,6 +356,16 @@ template void Vector3::print(void) const << std::endl; return; } + +/** + * @brief Overload "<<" tor print out a + * Vector3 on standard output + * + * @tparam T + * @param[in] os + * @param[in] u + * @return std::ostream& + */ template static std::ostream &operator<<(std::ostream &os, const Vector3 &u) { os << "(" << std::setw(10) << u.x << "," << std::setw(10) << u.y << "," << std::setw(10) << u.z << ")"; From 8fbb5bbd955f7d62b5a2199048f35055cf473788 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 11 Feb 2022 17:08:59 +0800 Subject: [PATCH 21/52] test: add UT and annotations for mathzone class --- source/module_base/mathzone.h | 197 ++++++++++++---------- source/module_base/test/CMakeLists.txt | 17 ++ source/module_base/test/mathzone_test.cpp | 86 ++++++++++ 3 files changed, 208 insertions(+), 92 deletions(-) create mode 100644 source/module_base/test/mathzone_test.cpp diff --git a/source/module_base/mathzone.h b/source/module_base/mathzone.h index 99622ac29a1..6e5180afd8c 100644 --- a/source/module_base/mathzone.h +++ b/source/module_base/mathzone.h @@ -1,69 +1,82 @@ #ifndef MATHZONE_H #define MATHZONE_H -#include "realarray.h" -#include "matrix3.h" #include "global_function.h" -#include -#include +#include "matrix3.h" +#include "realarray.h" + #include #include +#include +#include namespace ModuleBase { +/** + * @brief atomic coordinates conversion functions + * + */ class Mathzone { - public: - + public: Mathzone(); ~Mathzone(); - template - static T Max3(const T &a,const T &b,const T &c) + public: + /** + * @brief Pointwise product of two vectors with same size + * + * @tparam Type + * @param[in] f1 + * @param[in] f2 + * @return std::vector + * @author Peize Lin (2016-08-03) + */ + template + static std::vector Pointwise_Product(const std::vector &f1, const std::vector &f2) { - if (a>=b && a>=c) return a; - else if (b>=a && b>=c) return b; - else if (c>=a && c>=b) return c; - else throw std::runtime_error(ModuleBase::GlobalFunc::TO_STRING(__FILE__)+" line "+ModuleBase::GlobalFunc::TO_STRING(__LINE__)); + assert(f1.size() == f2.size()); + std::vector f(f1.size()); + for (int ir = 0; ir != f.size(); ++ir) + f[ir] = f1[ir] * f2[ir]; + return f; } - // be careful, this can only be used for plane wave - // during parallel calculation - -public: - - - - // Peize Lin add 2016-08-03 - template< typename Type > - static std::vector Pointwise_Product( const std::vector &f1, const std::vector &f2 ) - { - assert(f1.size()==f2.size()); - std::vector f(f1.size()); - for( int ir=0; ir!=f.size(); ++ir ) - f[ir] = f1[ir] * f2[ir]; - return f; - } - -//========================================================== -// MEMBER FUNCTION : -// NAME : Direct_to_Cartesian -// use lattice vector matrix R -// change the direct std::vector (dx,dy,dz) to cartesuab vectir -// (cx,cy,cz) -// (dx,dy,dz) = (cx,cy,cz) * R -// -// NAME : Cartesian_to_Direct -// the same as above -// (cx,cy,cz) = (dx,dy,dz) * R^(-1) -//========================================================== - static inline void Direct_to_Cartesian - ( - const double &dx,const double &dy,const double &dz, - const double &R11,const double &R12,const double &R13, - const double &R21,const double &R22,const double &R23, - const double &R31,const double &R32,const double &R33, - double &cx,double &cy,double &cz) + /** + * @brief change direct coordinate (dx,dy,dz) to + * Cartesian coordinate (cx,cy,cz), (dx,dy,dz) = (cx,cy,cz) * R + * + * @param[in] dx Direct coordinats + * @param[in] dy + * @param[in] dz + * @param[in] R11 Lattice vector matrix R_ij: i_row, j_column + * @param[in] R12 + * @param[in] R13 + * @param[in] R21 + * @param[in] R22 + * @param[in] R23 + * @param[in] R31 + * @param[in] R32 + * @param[in] R33 + * @param[out] cx Cartesian coordinats + * @param[out] cy + * @param[out] cz + */ + static inline void Direct_to_Cartesian(const double &dx, + const double &dy, + const double &dz, + const double &R11, + const double &R12, + const double &R13, + const double &R21, + const double &R22, + const double &R23, + const double &R31, + const double &R32, + const double &R33, + double &cx, + double &cy, + double &cz) { static ModuleBase::Matrix3 lattice_vector; static ModuleBase::Vector3 direct_vec, cartesian_vec; @@ -88,13 +101,41 @@ class Mathzone return; } - static inline void Cartesian_to_Direct - ( - const double &cx,const double &cy,const double &cz, - const double &R11,const double &R12,const double &R13, - const double &R21,const double &R22,const double &R23, - const double &R31,const double &R32,const double &R33, - double &dx,double &dy,double &dz) + /** + * @brief Change Cartesian coordinate (cx,cy,cz) to + * direct coordinate (dx,dy,dz), (cx,cy,cz) = (dx,dy,dz) * R^(-1) + * + * @param[in] cx Cartesian coordinats + * @param[in] cy + * @param[in] cz + * @param[in] R11 Lattice vector matrix R_ij: i_row, j_column + * @param[in] R12 + * @param[in] R13 + * @param[in] R21 + * @param[in] R22 + * @param[in] R23 + * @param[in] R31 + * @param[in] R32 + * @param[in] R33 + * @param[out] dx Direct coordinats + * @param[out] dy + * @param[out] dz + */ + static inline void Cartesian_to_Direct(const double &cx, + const double &cy, + const double &cz, + const double &R11, + const double &R12, + const double &R13, + const double &R21, + const double &R22, + const double &R23, + const double &R31, + const double &R32, + const double &R33, + double &dx, + double &dy, + double &dz) { static ModuleBase::Matrix3 lattice_vector, inv_lat; lattice_vector.e11 = R11; @@ -120,45 +161,17 @@ class Mathzone return; } - - static void To_Polar_Coordinate + void To_Polar_Coordinate ( - const double &x_cartesian, - const double &y_cartesian, - const double &z_cartesian, - double &r, - double &theta, - double &phi - ); - - - // coeff1 * x1 + (1-coeff1) * x2 - // Peize Lin add 2017-08-09 - template< typename T, typename T_coeff > - static T Linear_Mixing( const T & x1, const T & x2, const T_coeff & coeff1 ) - { - return coeff1 * x1 + (1-coeff1) * x2; - } - template< typename T, typename T_coeff > - static std::vector Linear_Mixing( const std::vector & x1, const std::vector & x2, const T_coeff & coeff1 ) - { - assert(x1.size()==x2.size()); - std::vector x; - for( size_t i=0; i!=x1.size(); ++i ) - x.push_back( Linear_Mixing( x1[i], x2[i], coeff1 ) ); - return x; - } - template< typename T1, typename T2, typename T_coeff > - static std::map Linear_Mixing( const std::map & x1, const std::map & x2, const T_coeff & coeff1 ) - { - std::map x; - for( const auto & x1i : x1 ) - x.insert( make_pair( x1i.first, Linear_Mixing( x1i.second, x2.at(x1i.first), coeff1 ) ) ); - return x; - } + const double &x_cartesian, + const double &y_cartesian, + const double &z_cartesian, + double &r, + double &theta, + double &phi); }; -} +} // namespace ModuleBase #endif diff --git a/source/module_base/test/CMakeLists.txt b/source/module_base/test/CMakeLists.txt index 1241410e700..e7c9d4f5749 100644 --- a/source/module_base/test/CMakeLists.txt +++ b/source/module_base/test/CMakeLists.txt @@ -38,3 +38,20 @@ AddTest( LIBS ${math_libs} SOURCES complexmatrix_test.cpp ../complexmatrix.cpp ../matrix.cpp ) +AddTest( + TARGET base_realarray + SOURCES realarray_test.cpp ../realarray.cpp +) +AddTest( + TARGET base_intarray + SOURCES intarray_test.cpp ../intarray.cpp +) +AddTest( + TARGET base_vector3 + SOURCES vector3_test.cpp +) +AddTest( + TARGET base_mathzone + LIBS ${math_libs} + SOURCES mathzone_test.cpp ../mathzone.cpp ../matrix3.cpp ../matrix.cpp ../tool_quit.cpp ../global_variable.cpp ../global_file.cpp ../memory.cpp ../timer.cpp +) diff --git a/source/module_base/test/mathzone_test.cpp b/source/module_base/test/mathzone_test.cpp new file mode 100644 index 00000000000..1ad11dd16ff --- /dev/null +++ b/source/module_base/test/mathzone_test.cpp @@ -0,0 +1,86 @@ +#include "../mathzone.h" +#include "../matrix3.h" +#include "../vector3.h" +#include "gtest/gtest.h" +#include "gmock/gmock.h" + +/************************************************ + * unit test of class Mathzone + ***********************************************/ + +/** + * - Tested Functions: + * - PointwiseProduct + * - return a vector, which is the pointwise + * - product of another two vectors of the same + * - length + * - Direct2Cartesian + * - change atomic coordinates from direct + * - to Cartesian + * - Cartesian2Direct + * - change atomic coordinates from Cartesian + * - to Direct + */ + +class MathzoneTest : public testing::Test +{ +protected: + double R11 = 3.68; double R12 = 0.00; double R13 = 0.00; + double R21 = 0.00; double R22 = 10.1; double R23 = 0.00; + double R31 = 0.00; double R32 = 0.00; double R33 = 26.7; + ModuleBase::Matrix3 lattice; + ModuleBase::Vector3 direct, cartesian; +}; + +TEST_F(MathzoneTest, PointwiseProduct) +{ + std::vector aa, bb, cc; + for(int i=0;i<10;i++) + { + aa.push_back(i*i); + bb.push_back(i*2); + } + cc = ModuleBase::Mathzone::Pointwise_Product(aa,bb); + for(int i=0;i<10;i++) + { + EXPECT_EQ(cc[i],i*i*i*2); + } +} + +TEST_F(MathzoneTest, Direct2Cartesian) +{ + direct.set(0.1,0.2,0.4); + cartesian.set(0.368,2.02,10.68); + ModuleBase::Vector3 cartnew; + ModuleBase::Mathzone::Direct_to_Cartesian(direct.x, + direct.y, + direct.z, + R11, R12, R13, + R21, R22, R23, + R31, R32, R33, + cartnew.x, + cartnew.y, + cartnew.z); + EXPECT_NEAR(cartnew.x,cartesian.x, 1e-15); + EXPECT_NEAR(cartnew.y,cartesian.y, 1e-15); + EXPECT_NEAR(cartnew.z,cartesian.z, 1e-15); +} + +TEST_F(MathzoneTest, Cartesian2Direct) +{ + direct.set(0.1,0.2,0.4); + cartesian.set(0.368,2.02,10.68); + ModuleBase::Vector3 directnew; + ModuleBase::Mathzone::Cartesian_to_Direct(cartesian.x, + cartesian.y, + cartesian.z, + R11, R12, R13, + R21, R22, R23, + R31, R32, R33, + directnew.x, + directnew.y, + directnew.z); + EXPECT_NEAR(directnew.x,direct.x, 1e-15); + EXPECT_NEAR(directnew.y,direct.y, 1e-15); + EXPECT_NEAR(directnew.z,direct.z, 1e-15); +} From f26107aecd9f9ec1df52c61bbddedcd83c0d40ec Mon Sep 17 00:00:00 2001 From: maki49 <1579492865@qq.com> Date: Sat, 12 Feb 2022 09:55:00 +0800 Subject: [PATCH 22/52] small changes --- source/module_orbital/test/ORB_unittest.h | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/source/module_orbital/test/ORB_unittest.h b/source/module_orbital/test/ORB_unittest.h index d817cf07e84..828b4890b6f 100644 --- a/source/module_orbital/test/ORB_unittest.h +++ b/source/module_orbital/test/ORB_unittest.h @@ -1,10 +1,10 @@ -#ifndef _ORBTEST_ -#define _ORBTEST_ +#ifndef _ORBUNITTEST_ +#define _ORBUNITTEST_ #include "gtest/gtest.h" -#include "../../../source/module_orbital/ORB_control.h" -#include "../../../source/module_base/global_function.h" -#include "../../../source/src_lcao/center2_orb-orb11.h" +#include "module_orbital/ORB_control.h" +#include "module_base/global_function.h" +#include "src_lcao/center2_orb-orb11.h" //#include "mock_center2.h" #include #include From a5d3e11a09392d5983ff05929b1b3d964c796679 Mon Sep 17 00:00:00 2001 From: xingliang Date: Mon, 14 Feb 2022 12:28:26 +0800 Subject: [PATCH 23/52] test: add the unit test of sph_bessel_recursive.h range: source/module_base/test --- source/module_base/test/CMakeLists.txt | 4 + .../test/sph_bessel_recursive_test.cpp | 84 +++++++++++++++++++ 2 files changed, 88 insertions(+) create mode 100644 source/module_base/test/sph_bessel_recursive_test.cpp diff --git a/source/module_base/test/CMakeLists.txt b/source/module_base/test/CMakeLists.txt index b7e9a9b7281..6a3a21017ee 100644 --- a/source/module_base/test/CMakeLists.txt +++ b/source/module_base/test/CMakeLists.txt @@ -42,4 +42,8 @@ AddTest( TARGET base_matrix LIBS ${math_libs} SOURCES matrix_test.cpp ../matrix.cpp +) +AddTest( + TARGET base_sph_bessel_recursive + SOURCES sph_bessel_recursive_test.cpp ../sph_bessel_recursive-d1.cpp ../sph_bessel_recursive-d2.cpp ) \ No newline at end of file diff --git a/source/module_base/test/sph_bessel_recursive_test.cpp b/source/module_base/test/sph_bessel_recursive_test.cpp new file mode 100644 index 00000000000..9355cb09031 --- /dev/null +++ b/source/module_base/test/sph_bessel_recursive_test.cpp @@ -0,0 +1,84 @@ +#include"../sph_bessel_recursive.h" +#include"gtest/gtest.h" + +#define threshold 1e-12 + +/************************************************ +* unit test of class Sph_Bessel_Recursive +***********************************************/ + +/** + * Note: this unit test try to ensure the invariance + * of the spherical Bessel produced by class Sph_Bessel_Recursive, + * and the reference results are produced by ModuleBase::Sph_Bessel_Recursive + * at 2022-1-25. + * + */ + +double mean(std::vector &vect) +{ + double meanv = 0.0; + + int totN = vect.size(); + for (int i=0; i< totN; ++i) {meanv += vect[i]/totN;} + + return meanv; +} + +TEST(SphBessel,D1) +{ + int lmax = 7; + int rmesh = 700; + double dx = 0.01; + + ModuleBase::Sph_Bessel_Recursive::D1 sphbesseld1; + sphbesseld1.set_dx(dx); + sphbesseld1.cal_jlx(lmax,rmesh); + std::vector> jlx = sphbesseld1.get_jlx(); + + ASSERT_EQ(jlx.size(),static_cast(lmax + 1)); + EXPECT_NEAR( mean(jlx[0])/0.2084468748396, 1.0, threshold); + EXPECT_NEAR( mean(jlx[1])/0.12951635180384, 1.0, threshold); + EXPECT_NEAR( mean(jlx[2])/0.124201140093879, 1.0, threshold); + EXPECT_NEAR( mean(jlx[3])/0.118268654505568, 1.0, threshold); + EXPECT_NEAR( mean(jlx[4])/0.0933871035384385, 1.0, threshold); + EXPECT_NEAR( mean(jlx[5])/0.0603800487910689, 1.0, threshold); + EXPECT_NEAR( mean(jlx[6])/0.0327117051555907, 1.0, threshold); + EXPECT_NEAR( mean(jlx[7])/0.0152155566653926, 1.0, threshold); +} + + +TEST(SphBessel,D2) +{ + int lmax = 7; + int rmesh = 700; + int kmesh = 800; + double dx = 0.0001; + + ModuleBase::Sph_Bessel_Recursive::D2 sphbesseld2; + sphbesseld2.set_dx(dx); + sphbesseld2.cal_jlx(lmax,rmesh,kmesh); + std::vector>> jlxd2 = sphbesseld2.get_jlx(); + std::vector> jlx(lmax+1); + + ASSERT_EQ(jlxd2.size(),static_cast(lmax + 1)); + + //calculate the mean of jlxd2[i][j] and assign to jlx[i][j] + for(int i=0; i Date: Tue, 15 Feb 2022 10:42:03 +0800 Subject: [PATCH 24/52] test: add the unit test and comments of math_sphbes.h range: source/module_base --- source/module_base/math_sphbes.h | 10 ++++- source/module_base/test/CMakeLists.txt | 4 ++ source/module_base/test/math_sphbes_test.cpp | 44 +++++++++++++++++++- 3 files changed, 56 insertions(+), 2 deletions(-) diff --git a/source/module_base/math_sphbes.h b/source/module_base/math_sphbes.h index 60245e95d72..78e6a228566 100644 --- a/source/module_base/math_sphbes.h +++ b/source/module_base/math_sphbes.h @@ -52,7 +52,15 @@ class Sphbes double *sjp ); - + /** + * @brief return num eigenvalues of spherical bessel function + * + * @param num [in] the number of eigenvalues + * @param l [in] angular number + * @param epsilon [in] the accuracy + * @param eigenvalue [out] the calculated eigenvalues + * @param rcut [in] the cutoff the radial function + */ static void Spherical_Bessel_Roots ( const int &num, diff --git a/source/module_base/test/CMakeLists.txt b/source/module_base/test/CMakeLists.txt index 4ddcc55e72a..ad36d6940bb 100644 --- a/source/module_base/test/CMakeLists.txt +++ b/source/module_base/test/CMakeLists.txt @@ -47,6 +47,10 @@ AddTest( TARGET base_sph_bessel_recursive SOURCES sph_bessel_recursive_test.cpp ../sph_bessel_recursive-d1.cpp ../sph_bessel_recursive-d2.cpp ) +AddTest( + TARGET base_math_sphbes + SOURCES math_sphbes_test.cpp ../math_sphbes.cpp ../timer.cpp +) AddTest( TARGET base_realarray SOURCES realarray_test.cpp ../realarray.cpp diff --git a/source/module_base/test/math_sphbes_test.cpp b/source/module_base/test/math_sphbes_test.cpp index 936e1a34245..40dc8343cb5 100644 --- a/source/module_base/test/math_sphbes_test.cpp +++ b/source/module_base/test/math_sphbes_test.cpp @@ -20,7 +20,9 @@ * and the reference results are produced by ABACUS * at 2022-1-27. * - * Tested function: Spherical_Bessel. + * Tested function: + * - Spherical_Bessel. + * - Spherical_Bessel_Roots * */ @@ -124,6 +126,46 @@ TEST_F(Sphbes,SphericalBessel) EXPECT_NEAR(mean(jl,msh)/0.015215556095798710851, 1.0,doublethreshold); } +TEST_F(Sphbes,SphericalBesselRoots) +{ + int neign = 100; + double **eign = new double*[8]; + for(int i=0;i<8;++i) + { + eign[i] = new double[neign]; + ModuleBase::Sphbes::Spherical_Bessel_Roots(neign,i,1.0e-12,eign[i],10.0); + } + + EXPECT_NEAR(eign[0][0]/0.31415926535899563188, 1.0,doublethreshold); + EXPECT_NEAR(eign[0][99]/31.415926535896932847, 1.0,doublethreshold); + EXPECT_NEAR(mean(eign[0],100)/15.865042900628463229, 1.0,doublethreshold); + EXPECT_NEAR(eign[1][0]/0.44934094579091843347, 1.0,doublethreshold); + EXPECT_NEAR(eign[1][99]/31.572689440204385392, 1.0,doublethreshold); + EXPECT_NEAR(mean(eign[1],100)/16.020655759558295017, 1.0,doublethreshold); + EXPECT_NEAR(eign[2][0]/0.57634591968946913276, 1.0,doublethreshold); + EXPECT_NEAR(eign[2][99]/31.729140298172534784, 1.0,doublethreshold); + EXPECT_NEAR(mean(eign[2],100)/16.175128483074864505, 1.0,doublethreshold); + EXPECT_NEAR(eign[3][0]/0.69879320005004752492, 1.0,doublethreshold); + EXPECT_NEAR(eign[3][99]/31.885283678838447941, 1.0,doublethreshold); + EXPECT_NEAR(mean(eign[3],100)/16.328616567969248763, 1.0,doublethreshold); + EXPECT_NEAR(eign[4][0]/0.81825614525711076741, 1.0,doublethreshold); + EXPECT_NEAR(eign[4][99]/32.041124042016576823, 1.0,doublethreshold); + EXPECT_NEAR(mean(eign[4],100)/16.481221742387987206, 1.0,doublethreshold); + EXPECT_NEAR(eign[5][0]/0.93558121110426506473, 1.0,doublethreshold); + EXPECT_NEAR(eign[5][99]/32.196665741899131774, 1.0,doublethreshold); + EXPECT_NEAR(mean(eign[5],100)/16.633019118735202113, 1.0,doublethreshold); + EXPECT_NEAR(eign[6][0]/1.051283540809391015, 1.0,doublethreshold); + EXPECT_NEAR(eign[6][99]/32.351913030537232885, 1.0,doublethreshold); + EXPECT_NEAR(mean(eign[6],100)/16.784067905062840964, 1.0,doublethreshold); + EXPECT_NEAR(eign[7][0]/1.1657032192516516567, 1.0,doublethreshold); + EXPECT_NEAR(eign[7][99]/32.506870061157627561, 1.0,doublethreshold); + EXPECT_NEAR(mean(eign[7],100)/16.934416735327332049, 1.0,doublethreshold); + + for(int i=0;i<8;++i) delete [] eign[i]; + delete [] eign; +} + + int main(int argc, char **argv) { #ifdef __MPI From aca89f46c7b33378dd57c6e7ab2321a6f7e14caa Mon Sep 17 00:00:00 2001 From: xingliang Date: Tue, 15 Feb 2022 10:45:34 +0800 Subject: [PATCH 25/52] fix the delete of pointer r jl --- source/module_base/test/math_sphbes_test.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/source/module_base/test/math_sphbes_test.cpp b/source/module_base/test/math_sphbes_test.cpp index 40dc8343cb5..fc556028550 100644 --- a/source/module_base/test/math_sphbes_test.cpp +++ b/source/module_base/test/math_sphbes_test.cpp @@ -57,8 +57,8 @@ class Sphbes : public testing::Test void TearDown() { - delete r; - delete jl; + delete [] r; + delete [] jl; } }; From 1ce08d00d7b6a5764493101384fa937368259b2f Mon Sep 17 00:00:00 2001 From: xingliang Date: Tue, 15 Feb 2022 14:13:08 +0800 Subject: [PATCH 26/52] loose the threshold and modify the comments --- source/module_base/math_sphbes.h | 8 ++++---- source/module_base/test/math_sphbes_test.cpp | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/source/module_base/math_sphbes.h b/source/module_base/math_sphbes.h index 78e6a228566..f7257b2dc3e 100644 --- a/source/module_base/math_sphbes.h +++ b/source/module_base/math_sphbes.h @@ -18,10 +18,10 @@ class Sphbes * @brief spherical bessel * * @param msh [in] number of grid points - * @param r [in] radial grid (1:msh) + * @param r [in] radial grid * @param q [in] k_radial * @param l [in] angular momentum - * @param jl [out] jl(1:msh) spherical bessel function + * @param jl [out] jl spherical bessel function */ static void Spherical_Bessel ( @@ -36,10 +36,10 @@ class Sphbes * @brief spherical bessel * * @param msh [in] number of grid points - * @param r [in] radial grid (1:msh) + * @param r [in] radial grid * @param q [in] k_radial * @param l [in] angular momentum - * @param jl [out] jl(1:msh) spherical bessel function + * @param jl [out] jl spherical bessel function * @param sjp [out] sjp[i] is assigned to be 1.0. i < msh. */ static void Spherical_Bessel diff --git a/source/module_base/test/math_sphbes_test.cpp b/source/module_base/test/math_sphbes_test.cpp index fc556028550..1b0dc208507 100644 --- a/source/module_base/test/math_sphbes_test.cpp +++ b/source/module_base/test/math_sphbes_test.cpp @@ -7,7 +7,7 @@ #include"gtest/gtest.h" -#define doublethreshold 1e-12 +#define doublethreshold 1e-7 /************************************************ From bbc471e450b389f1af64397e4d6fad01d5e42db3 Mon Sep 17 00:00:00 2001 From: ouqi0711 Date: Tue, 15 Feb 2022 15:15:45 +0800 Subject: [PATCH 27/52] modified the recommended values of exx related input parameters. --- doc/input-main.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/doc/input-main.md b/doc/input-main.md index 16e79bfdda6..b84290401ff 100644 --- a/doc/input-main.md +++ b/doc/input-main.md @@ -992,7 +992,7 @@ This part of variables are relevant when using hybrid functionals - exx_hybrid_type - *Type*: String - - *Description*: Type of hybrid functional used. Options are "hf" (pure Hartree-Fock), "pbe0"(PBE0), "hse" (Note: in order to use HSE functional, LIBXC is required). + - *Description*: Type of hybrid functional used. Options are "hf" (pure Hartree-Fock), "pbe0"(PBE0), "hse" (Note: in order to use HSE functional, LIBXC is required). Note also that HSE has been tested while PBE0 has NOT been fully tested yet. If set to "no", then no hybrid functional is used (i.e.,Fock exchange is not included.) @@ -1023,7 +1023,7 @@ adial integration for pseudopotentials, in Bohr. - exx_pca_threshold - *Type*: Real - - *Description*: To accelerate the evaluation of four-center integrals (ik|jl), the product of atomic orbitals are expanded in the basis of auxiliary basis functions (ABF): φiφj~CkijPk. The size of the ABF (i.e. number of Pk) is reduced using principal component analysis. When a large PCA threshold is used, the number of ABF will be reduced, hence the calculations becomes faster. However this comes at the cost of computational accuracy. A relatively safe choice of the value is 1d-3. + - *Description*: To accelerate the evaluation of four-center integrals (ik|jl), the product of atomic orbitals are expanded in the basis of auxiliary basis functions (ABF): φiφj~CkijPk. The size of the ABF (i.e. number of Pk) is reduced using principal component analysis. When a large PCA threshold is used, the number of ABF will be reduced, hence the calculations becomes faster. However this comes at the cost of computational accuracy. A relatively safe choice of the value is 1d-4. - *Default*: 0 [back to top](#input-file) @@ -1044,21 +1044,21 @@ adial integration for pseudopotentials, in Bohr. - exx_dm_threshold - *Type*: Real - - *Description*: The Fock exchange can be expressed as Σk,l(ik|jl)Dkl where D is the density matrix. Smaller values of the density matrix can be truncated to accelerate calculation. The larger the threshold is, the faster the calculation and the lower the accuracy. A relatively safe choice of the value is 1d-3. + - *Description*: The Fock exchange can be expressed as Σk,l(ik|jl)Dkl where D is the density matrix. Smaller values of the density matrix can be truncated to accelerate calculation. The larger the threshold is, the faster the calculation and the lower the accuracy. A relatively safe choice of the value is 1d-4. - *Default*: 0 [back to top](#input-file) - exx_schwarz_threshold - *Type*: Real - - *Description*: In practice the four-center integrals are sparse, and using Cauchy-Schwartz inequality, we can find an upper bound of each integral before carrying out explicit evaluations. Those that are smaller than exx_schwarz_threshold will be truncated. The larger the threshold is, the faster the calculation and the lower the accuracy. A relatively safe choice of the value is 1d-4. + - *Description*: In practice the four-center integrals are sparse, and using Cauchy-Schwartz inequality, we can find an upper bound of each integral before carrying out explicit evaluations. Those that are smaller than exx_schwarz_threshold will be truncated. The larger the threshold is, the faster the calculation and the lower the accuracy. A relatively safe choice of the value is 1d-5. - *Default*: 0 [back to top](#input-file) - exx_cauchy_threshold - *Type*: Real - - *Description*: In practice the Fock exchange matrix is sparse, and using Cauchy-Schwartz inequality, we can find an upper bound of each matrix element before carrying out explicit evaluations. Those that are smaller than exx_cauchy_threshold will be truncated. The larger the threshold is, the faster the calculation and the lower the accuracy. A relatively safe choice of the value is 1d-6. + - *Description*: In practice the Fock exchange matrix is sparse, and using Cauchy-Schwartz inequality, we can find an upper bound of each matrix element before carrying out explicit evaluations. Those that are smaller than exx_cauchy_threshold will be truncated. The larger the threshold is, the faster the calculation and the lower the accuracy. A relatively safe choice of the value is 1d-7. - *Default*: 0 [back to top](#input-file) From e67cec0a64e567e6148d1ffd7f0e70d26ca35fcc Mon Sep 17 00:00:00 2001 From: ouqi0711 Date: Tue, 15 Feb 2022 15:26:01 +0800 Subject: [PATCH 28/52] added a note of parallel job limitations for exx jobs --- doc/input-main.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/input-main.md b/doc/input-main.md index b84290401ff..ed71d83f3f3 100644 --- a/doc/input-main.md +++ b/doc/input-main.md @@ -992,7 +992,7 @@ This part of variables are relevant when using hybrid functionals - exx_hybrid_type - *Type*: String - - *Description*: Type of hybrid functional used. Options are "hf" (pure Hartree-Fock), "pbe0"(PBE0), "hse" (Note: in order to use HSE functional, LIBXC is required). Note also that HSE has been tested while PBE0 has NOT been fully tested yet. + - *Description*: Type of hybrid functional used. Options are "hf" (pure Hartree-Fock), "pbe0"(PBE0), "hse" (Note: in order to use HSE functional, LIBXC is required). Note also that HSE has been tested while PBE0 has NOT been fully tested yet, and the maxmum parallel cpus for running exx is Nx(N+1)/2, with N being the number of atoms. If set to "no", then no hybrid functional is used (i.e.,Fock exchange is not included.) From 057fe8f0fc3b8259c884c3463e0135065961ae00 Mon Sep 17 00:00:00 2001 From: xingliang Date: Thu, 17 Feb 2022 11:08:50 +0800 Subject: [PATCH 29/52] test: add the unit test and comments for math_polyint.h range: source/module_base/ --- source/module_base/math_polyint.h | 66 ++++++++- source/module_base/test/CMakeLists.txt | 4 + source/module_base/test/math_polyint_test.cpp | 134 ++++++++++++++++++ 3 files changed, 199 insertions(+), 5 deletions(-) create mode 100644 source/module_base/test/math_polyint_test.cpp diff --git a/source/module_base/math_polyint.h b/source/module_base/math_polyint.h index 0cc13f8f831..b15f50678c3 100644 --- a/source/module_base/math_polyint.h +++ b/source/module_base/math_polyint.h @@ -18,6 +18,19 @@ class PolyInt //======================================================== // Polynomial_Interpolation //======================================================== + + /** + * @brief Lagrange interpolation + * + * @param table [in] three dimension matrix, the data in 3rd dimension is used to do prediction + * @param dim1 [in] index of 1st dimension of table/y + * @param dim2 [in] index of 2nd dimension of table/y + * @param y [out] three dimension matrix to store the predicted value + * @param dim_y [in] index of 3rd dimension of y to store predicted value + * @param table_length [in] length of 3rd dimension of table + * @param table_interval [in] interval of 3rd dimension of table + * @param x [in] the position in 3rd dimension to be predicted + */ static void Polynomial_Interpolation ( const ModuleBase::realArray &table, @@ -30,6 +43,17 @@ class PolyInt const double &x ); + /** + * @brief Lagrange interpolation + * + * @param table [in] three dimension matrix, the data in 3rd dimension is used to do prediction + * @param dim1 [in] index of 1st dimension of table + * @param dim2 [in] index of 2nd dimension of table + * @param table_length [in] length of 3rd dimension of table + * @param table_interval [in] interval of 3rd dimension of table + * @param x [in] the position in 3rd dimension to be predicted + * @return double the predicted value + */ static double Polynomial_Interpolation ( const ModuleBase::realArray &table, @@ -37,10 +61,24 @@ class PolyInt const int &dim2, const int &table_length, const double &table_interval, - const double &x // input value + const double &x ); - static double Polynomial_Interpolation // pengfei Li 2018-3-23 + /** + * @brief Lagrange interpolation + * + * @param table [in] four dimension matrix, the data in 4th dimension is used to do prediction + * @param dim1 [in] index of 1st dimension of table + * @param dim2 [in] index of 2nd dimension of table + * @param dim3 [in] index of 3rd dimension of table + * @param table_length [in] length of 4th dimension of table + * @param table_interval [in] interval of 4th dimension of table + * @param x [in] the position in 4th dimension to be predicted + * @return double the predicted value + * @author pengfei Li + * @date 2018-3-23 + */ + static double Polynomial_Interpolation ( const ModuleBase::realArray &table, const int &dim1, @@ -48,23 +86,41 @@ class PolyInt const int &dim3, const int &table_length, const double &table_interval, - const double &x // input value + const double &x ); + /** + * @brief Lagrange interpolation + * + * @param table [in] the data used to do prediction + * @param table_length [in] length of table + * @param table_interval [in] interval of table + * @param x [in] the position to be predicted + * @return double the predicted value + */ static double Polynomial_Interpolation ( const double *table, const int &table_length, const double &table_interval, - const double &x // input value + const double &x ); + /** + * @brief Lagrange interpolation + * + * @param xpoint [in] array of postion + * @param ypoint [in] array of data to do prediction + * @param table_length [in] length of xpoint + * @param x [in] position to be predicted + * @return double predicted value + */ static double Polynomial_Interpolation_xy ( const double *xpoint, const double *ypoint, const int table_length, - const double &x // input value + const double &x ); }; diff --git a/source/module_base/test/CMakeLists.txt b/source/module_base/test/CMakeLists.txt index ad36d6940bb..c7cbd3c4380 100644 --- a/source/module_base/test/CMakeLists.txt +++ b/source/module_base/test/CMakeLists.txt @@ -67,4 +67,8 @@ AddTest( TARGET base_mathzone LIBS ${math_libs} SOURCES mathzone_test.cpp ../mathzone.cpp ../matrix3.cpp ../matrix.cpp ../tool_quit.cpp ../global_variable.cpp ../global_file.cpp ../memory.cpp ../timer.cpp +) +AddTest( + TARGET base_math_polyint + SOURCES math_polyint_test.cpp ../math_polyint.cpp ../realarray.cpp ../timer.cpp ) \ No newline at end of file diff --git a/source/module_base/test/math_polyint_test.cpp b/source/module_base/test/math_polyint_test.cpp new file mode 100644 index 00000000000..1dc3d7d852c --- /dev/null +++ b/source/module_base/test/math_polyint_test.cpp @@ -0,0 +1,134 @@ +#include"../math_polyint.h" +#include"gtest/gtest.h" +#include"../realarray.h" +#include + +#define doublethreshold 1e-9 + +/************************************************ +* unit test of class PolyInt +***********************************************/ + +/** + * This unit test is to verify the accuracy of + * interpolation method on the function sin(x)/x + * with a interval of 0.01. + * sin(x)/x is one of the solution of spherical bessel + * function when l=0. + * + * - Tested function: + * - 4 types of Polynomial_Interpolation + * - Polynomial_Interpolation_xy + */ + + +class bessell0 : public testing::Test +{ + protected: + + int TableLength = 400; + double interval = 0.01; + ModuleBase::realArray table3,table4; + ModuleBase::realArray y3; + double *tablex = new double[TableLength]; + double *tabley = new double[TableLength]; + + double Func(double x) {return sin(x)/x;} + + void SetUp() + { + table3.create(1,1,TableLength); + table4.create(1,1,1,TableLength); + y3.create(1,1,TableLength); + + for(int i=1;i Date: Thu, 17 Feb 2022 13:11:40 +0800 Subject: [PATCH 30/52] modify the name of Func to sinc put the "new" action of tablex and tabley in SetUp --- source/module_base/test/math_polyint_test.cpp | 64 ++++++++++--------- 1 file changed, 33 insertions(+), 31 deletions(-) diff --git a/source/module_base/test/math_polyint_test.cpp b/source/module_base/test/math_polyint_test.cpp index 1dc3d7d852c..bb5a5937e91 100644 --- a/source/module_base/test/math_polyint_test.cpp +++ b/source/module_base/test/math_polyint_test.cpp @@ -30,23 +30,25 @@ class bessell0 : public testing::Test double interval = 0.01; ModuleBase::realArray table3,table4; ModuleBase::realArray y3; - double *tablex = new double[TableLength]; - double *tabley = new double[TableLength]; + double *tablex; + double *tabley; - double Func(double x) {return sin(x)/x;} + double sinc(double x) {return sin(x)/x;} void SetUp() { + tablex = new double[TableLength]; + tabley = new double[TableLength]; table3.create(1,1,TableLength); table4.create(1,1,1,TableLength); y3.create(1,1,TableLength); for(int i=1;i Date: Thu, 17 Feb 2022 15:30:00 +0800 Subject: [PATCH 31/52] test:add the unittest and comments of math_ylmreal.h range: source/module_base --- source/module_base/math_ylmreal.h | 45 +++- source/module_base/test/CMakeLists.txt | 5 + source/module_base/test/math_ylmreal_test.cpp | 204 ++++++++++++++++++ 3 files changed, 244 insertions(+), 10 deletions(-) create mode 100644 source/module_base/test/math_ylmreal_test.cpp diff --git a/source/module_base/math_ylmreal.h b/source/module_base/math_ylmreal.h index 362f2dfdd82..c051ef95ca3 100644 --- a/source/module_base/math_ylmreal.h +++ b/source/module_base/math_ylmreal.h @@ -14,29 +14,54 @@ class YlmReal YlmReal(); ~YlmReal(); + /** + * @brief spherical harmonic function (real form) an array of vectors + * + * @param lmax2 [in] lmax2 = (lmax + 1)^2 ; lmax = angular quantum number + * @param ng [in] the number of vectors + * @param g [in] an array of vectors + * @param ylm [out] Ylm; column index represent vector, row index represent Y00, Y10, Y11, Y1-1, Y20,Y21,Y2-1,Y22.Y2-2,...; + */ static void Ylm_Real ( - const int lmax2, // lmax2 = (lmax+1)^2 - const int ng, // - const ModuleBase::Vector3 *g, // g_cartesian_vec(x,y,z) - matrix &ylm // output + const int lmax2, + const int ng, + const ModuleBase::Vector3 *g, + matrix &ylm ); + /** + * @brief spherical harmonic function (Herglotz generating form) of an array of vectors + * + * @param lmax2 [in] lmax2 = (lmax + 1)^2 ; lmax = angular quantum number + * @param ng [in] the number of vectors + * @param g [in] an array of vectors + * @param ylm [out] Ylm; column index represent vector, row index represent Y00, Y10, Y11, Y1-1, Y20,Y21,Y2-1,Y22.Y2-2,...; + */ static void Ylm_Real2 ( - const int lmax2, // lmax2 = (lmax+1)^2 - const int ng, // - const ModuleBase::Vector3 *g, // g_cartesian_vec(x,y,z) - matrix &ylm // output + const int lmax2, + const int ng, + const ModuleBase::Vector3 *g, + matrix &ylm ); + /** + * @brief spherical harmonic function (Herglotz generating form) of a vector + * + * @param lmax [in] maximum angular quantum number + * @param x [in] x part of the vector + * @param y [in] y part of the vector + * @param z [in] z part of the vector + * @param rly [in] Ylm, Y00, Y10, Y11, Y1-1, Y20,Y21,Y2-1,Y22.Y2-2,... + */ static void rlylm ( const int lmax, const double& x, const double& y, - const double& z, // g_cartesian_vec(x,y,z) - double* rly // output + const double& z, + double* rly ); private: diff --git a/source/module_base/test/CMakeLists.txt b/source/module_base/test/CMakeLists.txt index c7cbd3c4380..ddfa687934d 100644 --- a/source/module_base/test/CMakeLists.txt +++ b/source/module_base/test/CMakeLists.txt @@ -71,4 +71,9 @@ AddTest( AddTest( TARGET base_math_polyint SOURCES math_polyint_test.cpp ../math_polyint.cpp ../realarray.cpp ../timer.cpp +) +AddTest( + TARGET base_math_ylmreal + LIBS ${math_libs} + SOURCES math_ylmreal_test.cpp ../math_ylmreal.cpp ../realarray.cpp ../timer.cpp ../matrix.cpp ) \ No newline at end of file diff --git a/source/module_base/test/math_ylmreal_test.cpp b/source/module_base/test/math_ylmreal_test.cpp new file mode 100644 index 00000000000..455dfb7e95f --- /dev/null +++ b/source/module_base/test/math_ylmreal_test.cpp @@ -0,0 +1,204 @@ +#include"../math_ylmreal.h" +#include"../vector3.h" +#include"../matrix.h" +#include"gtest/gtest.h" + +#define PI 3.141592653589793238462643383279502884197169399 +#define doublethreshold 1e-12 + +/************************************************ +* unit test of class YlmReal and Ylm +***********************************************/ + +/** + * For lmax <5 cases, the reference values are calculated by the formula from + * https://formulasearchengine.com/wiki/Table_of_spherical_harmonics. Note, these + * formula lack of the Condon–Shortley phase (-1)^m, and in this unit test, item + * (-1)^m is multiplied. + * For lmax >=5, the reference values are calculated by YlmReal::Ylm_Real. + * + * - Tested functions of class YlmReal + * - Ylm_Real + * - Ylm_Real2 + * - rlylm + */ + + + +//mock functions of WARNING_QUIT and WARNING +namespace ModuleBase +{ + void WARNING_QUIT(const std::string &file,const std::string &description) {return ;} + void WARNING(const std::string &file,const std::string &description) {return ;} +} + + +class YlmRealTest : public testing::Test +{ + protected: + + int lmax = 7; + int ng = 4; //test the 4 selected points on the sphere + int nylm ; // total Ylm number; + ModuleBase::matrix ylm; + ModuleBase::Vector3 *g; + double *ref; + double *rly; + + //Ylm function + //https://formulasearchengine.com/wiki/Table_of_spherical_harmonics + //multipy the Condon–Shortley phase (-1)^m + inline double norm(const double &x, const double &y, const double &z) {return sqrt(x*x + y*y + z*z);} + double y00(const double &x, const double &y, const double &z) {return 1.0/2.0/sqrt(PI);} + double y10(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return sqrt(3.0/(4.0*PI)) * z / r;} + double y11(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*sqrt(3.0/(4.*PI)) * x / r;} + double y1m1(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*sqrt(3./(4.*PI)) * y / r;} // y1m1 means Y1,-1 + double y20(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return 1./4. * sqrt(5./PI) * (-1.*x*x - y*y + 2.*z*z) / (r*r);} + double y21(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*1./2. * sqrt(15./PI) * (z*x) / (r*r);} + double y2m1(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*1./2. * sqrt(15./PI) * (z*y) / (r*r);} + double y22(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return 1./4. * sqrt(15./PI) * (x*x - y*y) / (r*r);} + double y2m2(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return 1./2. * sqrt(15./PI) * (x*y) / (r*r);} + double y30(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return 1./4. * sqrt(7./PI) * z*(2.*z*z-3.*x*x-3.*y*y) / (r*r*r);} + double y31(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*1./4. * sqrt(21./2./PI) * x*(4.*z*z-x*x-y*y) / (r*r*r);} + double y3m1(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*1./4. * sqrt(21./2./PI) * y*(4.*z*z-x*x-y*y) / (r*r*r);} + double y32(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return 1./4. * sqrt(105./PI) * (x*x - y*y)*z / (r*r*r);} + double y3m2(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return 1./2. * sqrt(105./PI) * x*y*z / (r*r*r);} + double y33(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*1./4. * sqrt(35./2./PI) * x*(x*x - 3.*y*y) / (r*r*r);} + double y3m3(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*1./4. * sqrt(35./2./PI) * y*(3.*x*x - y*y) / (r*r*r);} + double y40(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return 3./16.*sqrt(1./PI) * (35.*z*z*z*z - 30.*z*z*r*r + 3*r*r*r*r) / (r*r*r*r);} + double y41(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*3./4.*sqrt(5./2./PI) * x*z*(7.*z*z - 3*r*r) / (r*r*r*r);} + double y4m1(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*3./4.*sqrt(5./2./PI) * y*z*(7.*z*z - 3.*r*r) / (r*r*r*r);} + double y42(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return 3./8.*sqrt(5./PI) * (x*x-y*y)*(7.*z*z-r*r) / (r*r*r*r);} + double y4m2(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return 3./4.*sqrt(5./PI) * x*y*(7.*z*z - r*r) / (r*r*r*r);} + double y43(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*3./4.*sqrt(35./2./PI) * x*z*(x*x - 3.*y*y) / (r*r*r*r);} + double y4m3(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*3./4.*sqrt(35./2./PI) * y*z*(3.*x*x - y*y) / (r*r*r*r);} + double y44(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return 3./16.*sqrt(35./PI) * (x*x*(x*x - 3.*y*y) - y*y*(3.*x*x-y*y)) / (r*r*r*r);} + double y4m4(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return 3./4.*sqrt(35./PI) * x*y*(x*x - y*y) / (r*r*r*r);} + + void SetUp() + { + nylm = (lmax + 1) * (lmax + 1); + ylm.create(nylm,ng); + g = new ModuleBase::Vector3[ng]; + g[0].set(1.0,0.0,0.0); + g[1].set(0.0,1.0,0.0); + g[2].set(0.0,0.0,1.0); + g[3].set(-1.0,-1.0,-1.0); + + rly = new double[nylm]; + ref = new double[nylm*ng]{ + y00(g[0].x, g[0].y, g[0].z), y00(g[1].x, g[1].y, g[1].z), y00(g[2].x, g[2].y, g[2].z), y00(g[3].x, g[3].y, g[3].z), + y10(g[0].x, g[0].y, g[0].z), y10(g[1].x, g[1].y, g[1].z), y10(g[2].x, g[2].y, g[2].z), y10(g[3].x, g[3].y, g[3].z), + y11(g[0].x, g[0].y, g[0].z), y11(g[1].x, g[1].y, g[1].z), y11(g[2].x, g[2].y, g[2].z), y11(g[3].x, g[3].y, g[3].z), + y1m1(g[0].x, g[0].y, g[0].z), y1m1(g[1].x, g[1].y, g[1].z), y1m1(g[2].x, g[2].y, g[2].z), y1m1(g[3].x, g[3].y, g[3].z), + y20(g[0].x, g[0].y, g[0].z), y20(g[1].x, g[1].y, g[1].z), y20(g[2].x, g[2].y, g[2].z), y20(g[3].x, g[3].y, g[3].z), + y21(g[0].x, g[0].y, g[0].z), y21(g[1].x, g[1].y, g[1].z), y21(g[2].x, g[2].y, g[2].z), y21(g[3].x, g[3].y, g[3].z), + y2m1(g[0].x, g[0].y, g[0].z), y2m1(g[1].x, g[1].y, g[1].z), y2m1(g[2].x, g[2].y, g[2].z), y2m1(g[3].x, g[3].y, g[3].z), + y22(g[0].x, g[0].y, g[0].z), y22(g[1].x, g[1].y, g[1].z), y22(g[2].x, g[2].y, g[2].z), y22(g[3].x, g[3].y, g[3].z), + y2m2(g[0].x, g[0].y, g[0].z), y2m2(g[1].x, g[1].y, g[1].z), y2m2(g[2].x, g[2].y, g[2].z), y2m2(g[3].x, g[3].y, g[3].z), + y30(g[0].x, g[0].y, g[0].z), y30(g[1].x, g[1].y, g[1].z), y30(g[2].x, g[2].y, g[2].z), y30(g[3].x, g[3].y, g[3].z), + y31(g[0].x, g[0].y, g[0].z), y31(g[1].x, g[1].y, g[1].z), y31(g[2].x, g[2].y, g[2].z), y31(g[3].x, g[3].y, g[3].z), + y3m1(g[0].x, g[0].y, g[0].z), y3m1(g[1].x, g[1].y, g[1].z), y3m1(g[2].x, g[2].y, g[2].z), y3m1(g[3].x, g[3].y, g[3].z), + y32(g[0].x, g[0].y, g[0].z), y32(g[1].x, g[1].y, g[1].z), y32(g[2].x, g[2].y, g[2].z), y32(g[3].x, g[3].y, g[3].z), + y3m2(g[0].x, g[0].y, g[0].z), y3m2(g[1].x, g[1].y, g[1].z), y3m2(g[2].x, g[2].y, g[2].z), y3m2(g[3].x, g[3].y, g[3].z), + y33(g[0].x, g[0].y, g[0].z), y33(g[1].x, g[1].y, g[1].z), y33(g[2].x, g[2].y, g[2].z), y33(g[3].x, g[3].y, g[3].z), + y3m3(g[0].x, g[0].y, g[0].z), y3m3(g[1].x, g[1].y, g[1].z), y3m3(g[2].x, g[2].y, g[2].z), y3m3(g[3].x, g[3].y, g[3].z), + y40(g[0].x, g[0].y, g[0].z), y40(g[1].x, g[1].y, g[1].z), y40(g[2].x, g[2].y, g[2].z), y40(g[3].x, g[3].y, g[3].z), + y41(g[0].x, g[0].y, g[0].z), y41(g[1].x, g[1].y, g[1].z), y41(g[2].x, g[2].y, g[2].z), y41(g[3].x, g[3].y, g[3].z), + y4m1(g[0].x, g[0].y, g[0].z), y4m1(g[1].x, g[1].y, g[1].z), y4m1(g[2].x, g[2].y, g[2].z), y4m1(g[3].x, g[3].y, g[3].z), + y42(g[0].x, g[0].y, g[0].z), y42(g[1].x, g[1].y, g[1].z), y42(g[2].x, g[2].y, g[2].z), y42(g[3].x, g[3].y, g[3].z), + y4m2(g[0].x, g[0].y, g[0].z), y4m2(g[1].x, g[1].y, g[1].z), y4m2(g[2].x, g[2].y, g[2].z), y4m2(g[3].x, g[3].y, g[3].z), + y43(g[0].x, g[0].y, g[0].z), y43(g[1].x, g[1].y, g[1].z), y43(g[2].x, g[2].y, g[2].z), y43(g[3].x, g[3].y, g[3].z), + y4m3(g[0].x, g[0].y, g[0].z), y4m3(g[1].x, g[1].y, g[1].z), y4m3(g[2].x, g[2].y, g[2].z), y4m3(g[3].x, g[3].y, g[3].z), + y44(g[0].x, g[0].y, g[0].z), y44(g[1].x, g[1].y, g[1].z), y44(g[2].x, g[2].y, g[2].z), y44(g[3].x, g[3].y, g[3].z), + y4m4(g[0].x, g[0].y, g[0].z), y4m4(g[1].x, g[1].y, g[1].z), y4m4(g[2].x, g[2].y, g[2].z), y4m4(g[3].x, g[3].y, g[3].z), + 0.000000000000000, 0.000000000000000, 0.935602579627389, 0.090028400200397, + -0.452946651195697, -0.000000000000000, -0.000000000000000, -0.348678494661834, + -0.000000000000000, -0.452946651195697, -0.000000000000000, -0.348678494661834, + -0.000000000000000, 0.000000000000000, 0.000000000000000, -0.000000000000000, + -0.000000000000000, -0.000000000000000, 0.000000000000000, -0.000000000000000, + 0.489238299435250, 0.000000000000000, -0.000000000000000, -0.376615818502422, + 0.000000000000000, -0.489238299435250, -0.000000000000000, 0.376615818502422, + 0.000000000000000, 0.000000000000000, 0.000000000000000, 0.532615198330370, + 0.000000000000000, 0.000000000000000, 0.000000000000000, -0.000000000000000, + -0.656382056840170, -0.000000000000000, -0.000000000000000, -0.168427714314628, + -0.000000000000000, -0.656382056840170, -0.000000000000000, -0.168427714314628, + -0.317846011338142, -0.317846011338142, 1.017107236282055, 0.226023830284901, + -0.000000000000000, -0.000000000000000, -0.000000000000000, 0.258942827786103, + -0.000000000000000, -0.000000000000000, -0.000000000000000, 0.258942827786103, + 0.460602629757462, -0.460602629757462, 0.000000000000000, -0.000000000000000, + 0.000000000000000, 0.000000000000000, 0.000000000000000, -0.409424559784410, + -0.000000000000000, -0.000000000000000, -0.000000000000000, 0.136474853261470, + -0.000000000000000, 0.000000000000000, -0.000000000000000, -0.136474853261470, + -0.504564900728724, -0.504564900728724, 0.000000000000000, -0.598002845308118, + -0.000000000000000, -0.000000000000000, 0.000000000000000, 0.000000000000000, + -0.000000000000000, -0.000000000000000, -0.000000000000000, 0.350610246256556, + -0.000000000000000, -0.000000000000000, -0.000000000000000, 0.350610246256556, + 0.683184105191914, -0.683184105191914, 0.000000000000000, -0.000000000000000, + 0.000000000000000, 0.000000000000000, 0.000000000000000, -0.202424920056864, + 0.000000000000000, 0.000000000000000, 1.092548430592079, -0.350435072502801, + 0.451658037912587, 0.000000000000000, -0.000000000000000, 0.046358202625865, + 0.000000000000000, 0.451658037912587, -0.000000000000000, 0.046358202625865, + 0.000000000000000, -0.000000000000000, 0.000000000000000, 0.000000000000000, + 0.000000000000000, 0.000000000000000, 0.000000000000000, 0.492067081245654, + -0.469376801586882, -0.000000000000000, -0.000000000000000, 0.187354445356332, + -0.000000000000000, 0.469376801586882, -0.000000000000000, -0.187354445356332, + 0.000000000000000, 0.000000000000000, 0.000000000000000, 0.355076798886913, + 0.000000000000000, 0.000000000000000, 0.000000000000000, -0.000000000000000, + 0.518915578720260, 0.000000000000000, -0.000000000000000, -0.443845998608641, + 0.000000000000000, 0.518915578720260, -0.000000000000000, -0.443845998608641, + 0.000000000000000, -0.000000000000000, 0.000000000000000, 0.000000000000000, + 0.000000000000000, 0.000000000000000, 0.000000000000000, 0.452635881587108, + -0.707162732524596, 0.000000000000000, -0.000000000000000, 0.120972027847095, + -0.000000000000000, 0.707162732524596, -0.000000000000000, -0.120972027847095 + } ; + + + } + + void TearDown() + { + delete [] g; + delete [] ref; + delete [] rly; + } +}; + +TEST_F(YlmRealTest,YlmReal) +{ + ModuleBase::YlmReal::Ylm_Real(nylm,ng,g,ylm); + for(int i=0;i= 2) - { - if(in == il) - { - c_ln_c[il][in] = -c_ln_c[il-2][in-2]; - c_ln_s[il][in] = -c_ln_s[il-2][in-2]; - } - else - { - c_ln_c[il][in] = (2*il-1)*c_ln_c[il-1][in] - c_ln_c[il-2][in-2]; - c_ln_s[il][in] = (2*il-1)*c_ln_s[il-1][in] - c_ln_s[il-2][in-2]; - } - } - else - { - //in = 0 or 1, but here il = 2, so in != il - c_ln_c[il][in] = (2*il-1)*c_ln_c[il-1][in]; - c_ln_s[il][in] = (2*il-1)*c_ln_s[il-1][in]; - } - } - } - - ModuleBase::timer::tick("Mathzone_Add1","expand_coef_jlx"); - return; -} - -void Mathzone_Add1::Spherical_Bessel -( - const int &msh, //number of grid points - const double *r,//radial grid - const double &q, // - const int &l, //angular momentum - double *sj, //jl(1:msh) = j_l(q*r(i)),spherical bessel function - double *sjp -) -{ - ModuleBase::timer::tick ("Mathzone_Add1","Spherical_Bessel"); - - assert (l <= Mathzone_Add1::sph_lmax); - - //creat coefficients - if (!flag_jlx_expand_coef) - { - Mathzone_Add1::expand_coef_jlx (); - flag_jlx_expand_coef = true; - } - - //epsilon - const double eps = 1.0E-10; - - /****************************************************************** - jlx = \sum_{n=0}^{l}(c_{ln}^{s}sin(x) + c_{ln}^{c}cos(x))/x^{l-n+1} - jldx = \sum_{n=0}^{l}(c_{ln}^{s}sin(x) + c_{ln}^{c}cos(x))/x^{l-n+1} - *******************************************************************/ - for (int ir = 0; ir < msh; ir++) - { - double qr = q * r[ir]; - - //judge qr approx 0 - if (fabs(qr) < eps) - { - if (l == 0) sj[ir] = 1.0; - else sj[ir] = 0.0; - - if (l == 1) sjp[ir] = 1.0 / 3.0; - else sjp[ir] = 0.0; - } - else - { - sj[ir] = 0.0; - sjp[ir] = 0.0; - - double lqr = pow (qr, l); - - //divided fac - double xqr = 1.0; - - for (int n = 0; n <= l; n++) - { - double com1 = (c_ln_s[l][n] * sin(qr) + c_ln_c[l][n] * cos(qr)) * xqr; - double com2 = (c_ln_s[l][n] * cos(qr) - c_ln_c[l][n] * sin(qr)) * xqr * qr - + (c_ln_s[l][n] * sin(qr) + c_ln_c[l][n] * cos(qr)) * n * xqr; - - sj[ir] += com1; - sjp[ir] += (com2 - l*com1); - - xqr *= qr; - } - - sj[ir] /= lqr; - sj[ir] /= (lqr * qr); - } - } - - ModuleBase::timer::tick ("Mathzone_Add1","Spherical_Bessel"); - - return; - -} - -double Mathzone_Add1::uni_simpson -( - const double* func, - const int& mshr, - const double& dr -) -{ - ModuleBase::timer::tick("Mathzone_Add1","uni_simpson"); - - assert(mshr >= 3); - - int ir=0; - int idx=0; - int msh_left=0; - double sum = 0.0; - - //f(a) - sum += func[0]; - - //simpson 3/8 rule - if(mshr % 2 == 0) - { - //simpson 3/8 rule - sum += 9.0 / 8.0 * (func[mshr-1] + 3.0*(func[mshr-2] + func[mshr-3]) + func[mshr-4]); - msh_left = mshr - 3; - } - else - { - msh_left = mshr; - } - - //f(b) - sum += func[msh_left-1]; - - //points left - for(ir = 0; ir < msh_left / 2; ir++) - { - idx = 2*ir+1; - sum += 4.0 * func[idx]; - } - - for(ir = 1; ir < msh_left / 2; ir++) - { - idx = 2*ir; - sum += 2.0 * func[idx]; - } - - ModuleBase::timer::tick("Mathzone_Add1","uni_simpson"); - return sum * dr / 3.0; -} - -void Mathzone_Add1::uni_radfft -( - const int& n, - const int& aml, - const int& mshk, - const double* arr_k, - const int& mshr, - const double* ri, - const double& dr, - const double* phir, - double* phik -) -{ - ModuleBase::timer::tick ("Mathzone_Add1","uni_radfft"); - - //allocate memory - double* fi = new double[mshr]; - double* fi_cp = new double[mshr]; - double* jl = new double[mshr]; - - //function to be integrated: r^2 phir - for (int ir = 0; ir < mshr; ir++) - { - fi_cp[ir] = pow(ri[ir], n) * phir[ir]; - } - - //integration - for (int ik = 0; ik < mshk; ik++) - { - //calculate spherical bessel - ModuleBase::Sphbes::Spherical_Bessel(mshr, ri, arr_k[ik], aml, jl); - - //functions to be integrated - for (int ir = 0; ir < mshr; ir++) - { - fi[ir] = fi_cp[ir] * jl[ir]; - } - - phik[ik] = uni_simpson (fi, mshr, dr); - } - - //deallocate memory - delete[] fi; - delete[] fi_cp; - delete[] jl; - - ModuleBase::timer::tick("Mathzone_Add1","uni_radfft"); - - return; -} - -void Mathzone_Add1::Sph_Bes (double x, int lmax, double *sb, double *dsb) -{ - ModuleBase::timer::tick("Mathzone_Add1","Spherical_Bessel"); - - int m, n, nmax; - double j0, j1, sf, tmp, si, co, ix; // j0p, j1p, ix2 - - if (x < 0.0) - { - std::cout << "\nminus x is invalid for Spherical_Bessel" << std::endl; - exit(0); // mohan add 2021-05-06 - } - - const double xmin = 1E-10; - - /* find an appropriate nmax */ - nmax = lmax + 3*static_cast(x) + 20; - if (nmax < 100) nmax = 100; - - /* allocate tsb */ - double* tsb = new double[nmax+1]; - - /* if x is larger than xmin */ - - if (xmin < x) - { - /* initial values*/ - tsb[nmax] = 0.0; - tsb[nmax-1] = 1.0e-14; - - /* downward recurrence from nmax-2 to lmax+2 */ - - for (n = nmax-1; (lmax+2) < n; n--) - { - tsb[n-1] = (2.0*n + 1.0)/x*tsb[n] - tsb[n+1]; - - if (1.0e+250 < tsb[n-1]) - { - tmp = tsb[n-1]; - tsb[n-1] /= tmp; - tsb[n ] /= tmp; - } - } - - /* downward recurrence from lmax+1 to 0 */ - n = lmax + 3; - tmp = tsb[n-1]; - tsb[n-1] /= tmp; - tsb[n ] /= tmp; - - for (n = lmax+2; 0 < n; n--) - { - tsb[n-1] = (2.0*n + 1.0)/x*tsb[n] - tsb[n+1]; - - if (1.0e+250 < tsb[n-1]) - { - tmp = tsb[n-1]; - for (m = n-1; m <= lmax+1; m++) - { - tsb[m] /= tmp; - } - } - } - - /* normalization */ - si = sin(x); - co = cos(x); - ix = 1.0/x; - //ix2 = ix*ix; - j0 = si*ix; - j1 = si*ix*ix - co*ix; - - if (fabs(tsb[1]) < fabs(tsb[0])) sf = j0/tsb[0]; - else sf = j1/tsb[1]; - - /* tsb to sb */ -// for (n = 0; n <= lmax+1; n++) - for (n = 0; n <= lmax; n++) - { - sb[n] = tsb[n]*sf; - } - - /* derivative of sb */ - dsb[0] = co*ix - si*ix*ix; - // for (n = 1; n <= lmax; n++) - for (n = 1; n < lmax; n++) - { - dsb[n] = ( (double)n*sb[n-1] - (double)(n+1.0)*sb[n+1] )/(2.0*(double)n + 1.0); - } - - n = lmax; - dsb[n] = ( (double)n*sb[n-1] - (double)(n+1.0)*sf*tsb[n+1] )/(2.0*(double)n + 1.0); - } - - /* if x is smaller than xmin */ - else - { - /* sb */ - for (n = 0; n <= lmax; n++ ) - { - sb[n] = 0.0; - } - sb[0] = 1.0; - - /* derivative of sb */ - dsb[0] = 0.0; - dsb[1] = 1.0 / 3.0; - for (n = 2; n <= lmax; n++) - { -// dsb[n] = ( (double)n*sb[n-1] - (double)(n+1.0)*sb[n+1] )/(2.0*(double)n + 1.0); - dsb[n] = 0.0; - } - } - - /* free tsb */ - delete [] tsb; - - ModuleBase::timer::tick("Mathzone_Add1","Spherical_Bessel"); - - return; -} - -void Mathzone_Add1::Sbt_new -( - const int& polint_order, - const int& l, - const double* k, - const double& dk, - const int& mshk, - const double* r, - const double& dr, - const int& mshr, - const double* fr, - const int& rpow, - double* fk -) -{ - ModuleBase::timer::tick ("Mathzone_Add1","Sbt_new"); - - //check parameter - assert (mshr >= 1); - assert (l >= 0); - assert (rpow >= 0 && rpow <=2); - - //step 0 - //l is odd or even - bool parity_flag; - if (l % 2 == 0) - { - parity_flag = true; - } - else - { - parity_flag = false; - } - - ModuleBase::GlobalFunc::ZEROS (fk, mshk); - - if (polint_order != 3) - { - std::cout << "\nhigh order interpolation is not available!" << std::endl; - exit(0); // mohan add 2021-05-06 - //ModuleBase::QUIT(); - } - - /********************************** - function multiplied by power of r - for different polint_order - **********************************/ - double* fr2; - double* fr3; - - //polint_order == 1 - fr2 = new double[mshr]; - if (rpow == 0) for (int ir = 0; ir < mshr; ir++) fr2[ir] = fr[ir]*r[ir]*r[ir]; - else if (rpow == 1) for (int ir = 0; ir < mshr; ir++) fr2[ir] = fr[ir]*r[ir]; - else if (rpow == 2) for (int ir = 0; ir < mshr; ir++) fr2[ir] = fr[ir]; - - fr3 = new double[mshr]; - for (int ir = 0; ir < mshr; ir++) fr3[ir] = fr2[ir] * r[ir]; - -// const int polint_order = 3; - int nu_pol_coef = (polint_order+1)*(mshk-1); - double* polint_coef = new double[nu_pol_coef]; - - //step 1 - //start calc - if (parity_flag) - { - //even - const int n = l/2; - - //coef for interpolation - int ct = 0; - - double ft_save = fourier_cosine_transform (fr2, r, mshr, dr, k[0]); - double dft_save = -fourier_sine_transform (fr3, r, mshr, dr, k[0]); - - for (int ik = 0; ik < mshk-1; ik++) - { - double ft0 = ft_save; - double dft0 = dft_save; - - double ft1 = fourier_cosine_transform (fr2, r, mshr, dr, k[ik+1]); - double dft1 = -fourier_sine_transform (fr3, r, mshr, dr, k[ik+1]); - - //double d2k = dk*dk; - //double d3k = d2k*dk; - - double c0 = ft0; - double c1 = dft0; - double c2 = 3.0*(ft1-ft0)/dk/dk-(dft1+2.0*dft0)/dk; - double c3 = (-2.0*(ft1-ft0)/dk+(dft1+dft0))/dk/dk; - - double k2 = k[ik]*k[ik]; - double k3 = k2*k[ik]; - - polint_coef[ct] = c0-c1*k[ik]+c2*k2-c3*k3; - polint_coef[ct+1] = c1-2.0*c2*k[ik]+3.0*c3*k2; - polint_coef[ct+2] = c2-3.0*k[ik]*c3; - polint_coef[ct+3] = c3; - - //test - // double x = (k[ik]+k[ik+1])/2; - // double x = k[ik]; -// double tmp = polint_coef[ct]+x*polint_coef[ct+1]+x*x*polint_coef[ct+2]+x*x*x*polint_coef[ct+3]; - // double tmp_ana = fourier_cosine_transform (fr2, r, mshr, dr, x); - // std::cout << "\ninterp = " << tmp << " ana = " << tmp_ana << " diff = " << log(fabs(tmp-tmp_ana))/log(10); - - //update - ct += (polint_order+1); - ft_save = ft1; - dft_save = dft1; - } - - //store coefficients for calculation - double* coef = new double[n+1]; - double fac = dualfac (l-1) / dualfac (l); - - for (int j = 0; j < n; j++) - { - coef[j] = fac; - - //update - int twoj = 2*j; - fac *= -static_cast((l+twoj+1)*(l-twoj))/(twoj+2)/(twoj+1); - } - coef[n] = pow (-1.0, n) * dualfac (2*l-1) / factorial (l); - - //start calc - //special case k = 0; - if (n ==0) fk[0] = uni_simpson (fr2, mshr, dr); - else fk[0] = 0.0; - - //k > 0 - for (int j =0; j <=n; j++) - { - //initialize Snm - double Snm = 0.0; - for (int ik = 1; ik < mshk; ik++) - { - Snm += pol_seg_int (polint_order, polint_coef, 2*j, k, ik); - double k2j = pow (k[ik], 2*j+1); - - fk[ik] += coef[j] * Snm / k2j; - } - } - //free - delete [] coef; - } - else - { - //odd - const int n = (l-1)/2; - - //coef for interpolation - int ct = 0; - double ft_save, dft_save; - ft_save = fourier_sine_transform (fr2, r, mshr, dr, k[0]); - dft_save = fourier_cosine_transform (fr3, r, mshr, dr, k[0]); - - for (int ik = 0; ik < mshk-1; ik++) - { - double ft0 = ft_save; - double dft0 = dft_save; - - double ft1 = fourier_sine_transform (fr2, r, mshr, dr, k[ik+1]); - double dft1 = fourier_cosine_transform (fr3, r, mshr, dr, k[ik+1]); - - double c0, c1, c2, c3; - c0 = ft0; - c1 = dft0; - c2 = 3.0*(ft1-ft0)/dk/dk-(dft1+2.0*dft0)/dk; - c3 = (-2.0*(ft1-ft0)/dk+(dft1+dft0))/dk/dk; - - double k2 = k[ik]*k[ik]; - double k3 = k2*k[ik]; - - polint_coef[ct] = c0-c1*k[ik]+c2*k2-c3*k3; - polint_coef[ct+1] = c1-2.0*c2*k[ik]+3.0*c3*k2; - polint_coef[ct+2] = c2-3.0*k[ik]*c3; - polint_coef[ct+3] = c3; - -// double x = (k[ik]+k[ik+1])/2; -// double x = k[ik]; -// double tmp = polint_coef[ct]+x*polint_coef[ct+1]+x*x*polint_coef[ct+2]+x*x*x*polint_coef[ct+3]; -// std::cout << "\ninterp = " << tmp << " ana = " << fourier_sine_transform (fr2, r, mshr, dr, x); - //update - ct += (polint_order+1); - ft_save = ft1; - dft_save = dft1; - } - - //store coefficients for calculation - double* coef = new double[n+1]; - double fac = dualfac (l) / dualfac (l-1); - - for (int j = 0; j < n; j++) - { - coef[j] = fac; - - //update - int twoj = 2*j; - fac *= -static_cast((l+twoj+2)*(l-twoj-1))/(twoj+3)/(twoj+2); - - //test -// std::cout << "\ncoef[j] = " << coef[j] << std::endl; - } - coef[n] = pow (-1.0, n) * dualfac (2*l-1) / factorial (l); - - //start calc - //special case k =0 ; - fk[0] = 0.0; - - //k > 0 - for (int j = 0; j <= n; j++) - { - double Snm = 0.0; - for (int ik = 1; ik < mshk; ik++) - { - Snm += pol_seg_int (polint_order, polint_coef, 2*j+1, k, ik); - double k2j = pow(k[ik], 2*j+2); - - fk[ik] += coef[j] * Snm / k2j; - } - } - - //free - delete [] coef; - } - - delete [] fr2; - delete [] fr3; - delete [] polint_coef; - - ModuleBase::timer::tick ("Mathzone_Add1","Sbt_new"); - return; -} - double Mathzone_Add1::factorial (const int& l) { if (l == 0 || l == 1) return 1.0; @@ -650,248 +50,6 @@ double Mathzone_Add1::dualfac (const int& l) else return l * dualfac (l-2); } -double Mathzone_Add1::pol_seg_int -( - const int& polint_order, - const double* coef, - const int& n, - const double* k, - const int& ik -) -{ - double val = 0.0; - double kmf = pow (k[ik], n+1); - double kmb = pow (k[ik-1], n+1); - - int cstart = (polint_order+1)*(ik-1); - for (int i = 0; i <= polint_order; i++) - { - val += coef[cstart+i]*(kmf-kmb)/(n+i+1); - /* - if (ik == 110) - { - std::cout << "i = " << i << " coef = " << coef[cstart+i] << " df = " << kmf-kmb << std::endl; - } - */ - //update - kmf *= k[ik]; - kmb *= k[ik-1]; - } - - /* - if (ik == 110) - { - std::cout << "val = " << val << std::endl; - ModuleBase::QUIT (); - } - */ - return val; -} - -double Mathzone_Add1::fourier_sine_transform -( - const double* func, - const double* r, - const int& mshr, - const double& dr, - const double& k -) -{ - ModuleBase::timer::tick ("Mathzone_Add1","Fsin"); - double val = 0.0; - double* sinf = new double[mshr]; - for (int ir = 0; ir < mshr; ir++) - { - sinf[ir] = func[ir] * sin(k*r[ir]); - } - val = uni_simpson (sinf, mshr, dr); - delete[] sinf; - ModuleBase::timer::tick ("Mathzone_Add1","Fsin"); - return val; -} - -double Mathzone_Add1::fourier_cosine_transform -( - const double* func, - const double* r, - const int& mshr, - const double& dr, - const double& k -) -{ - ModuleBase::timer::tick ("Mathzone_Add1","Fcos"); - double val = 0.0; - double* cosf = new double[mshr]; - for (int ir = 0; ir < mshr; ir++) - { - cosf[ir] = func[ir] * cos(k*r[ir]); - } - - val = uni_simpson (cosf, mshr, dr); - delete[] cosf; - ModuleBase::timer::tick ("Mathzone_Add1","Fcos"); - return val; -} - -void Mathzone_Add1::test () -{ - int polint_order =3; - int dim = 2048; - int ci = 1; - int l = 0; - double rmax = 20; - double dr = rmax/dim; - - double* rad = new double[dim]; - double* func = new double[dim]; - double* fk = new double[dim]; - for (int ir = 0; ir < dim; ir++) - { - rad[ir] = ir * dr; - func[ir] = pow(rad[ir], l) * exp(-ci*rad[ir]*rad[ir]); - fk[ir] = 0.0; - } - - Sbt_new (polint_order, l, rad, dr, dim, rad, dr, dim, func, 0, fk); - - for (int ik = 0; ik < dim; ik++) - { - double diff = fk[ik]- sqrt(ModuleBase::PI/4/ci)/pow(2.0*ci, l+1)* std::pow(rad[ik], l) * exp(-rad[ik]*rad[ik]/4/ci); - std::cout << rad[ik] << " " << fk[ik] << " " << sqrt(ModuleBase::PI/4/ci)/pow(2.0*ci, l+1)*pow(rad[ik], l)*exp(-rad[ik]*rad[ik]/4/ci) - << " " << std::log(fabs(diff))/std::log(10.0) << std::endl; - } - - delete[] rad; - delete[] func; - delete[] fk; - return; -} - -void Mathzone_Add1::test2 () -{ - int polint_order =3; - int N = 200; - int ci = 1; - int l = 0; - double rmax = 20; - double dr = rmax/(N-1); - - double dk = ModuleBase::PI / rmax /2; -// double kmax = PI / dr; -// double dk = dr; - - double* rad = new double[N]; - double* kad = new double[N]; - double* func = new double[N]; - double* fk = new double[N]; - double* fr = new double[N]; - for (int ir = 0; ir < N; ir++) - { - rad[ir] = ir * dr; - kad[ir] = ir * dk; - func[ir] = pow(rad[ir], l) * exp(-ci*rad[ir]*rad[ir]); - fk[ir] = 0.0; - fr[ir] = 0.0; - } - - Sbt_new (polint_order, l, kad, dk, N, rad, dr, N, func, 0, fk); - -/* - for (int ik = 0; ik < N; ik++) - { - double diff = fk[ik]- sqrt(PI/4/ci)/pow(2*ci, l+1)*pow(kad[ik], l)*exp(-kad[ik]*kad[ik]/4/ci); - std::cout << kad[ik] << " " << fk[ik] << " " << sqrt(PI/4/ci)/pow(2*ci, l+1)*pow(kad[ik], l)*exp(-kad[ik]*kad[ik]/4/ci) - << " " << log(fabs(diff))/log(10) << std::endl; - } - ModuleBase::QUIT (); -*/ - Sbt_new (polint_order, l, rad, dr, N, kad, dk, N, fk, 0, fr); - - for (int ir = 0; ir < N; ir++) - { - std::cout << ir*dr << " " << func[ir] << " " << fr[ir] *2.0 / ModuleBase::PI << " " << std::log(fabs(fr[ir]*2.0/ModuleBase::PI-func[ir]))/std::log(10.0) << std::endl; - } - - - delete[] rad; - delete[] func; - delete[] fk; - delete[] fr; - delete[] kad; - return; -} - -double Mathzone_Add1::Polynomial_Interpolation -( - const double* xa, - const double* ya, - const int& n, - const double& x -) -{ - ModuleBase::timer::tick("Mathzone_Add1","Polynomial_Interpolation"); - - int i, m, ns; - double den, dif, dift, ho, hp, w, rs, drs; - - //zero offset - const double* Cxa = xa - 1; - const double* Cya = ya - 1; - - double* cn = new double[n+1]; - double* dn = new double[n+1]; - - ns = 1; - dif = fabs(x - Cxa[1]); - - for(i = 1; i <= n; i++) - { - dift = fabs(x - Cxa[i]); - if(dift < dif) - { - ns = i; - dif = dift; - } - cn[i] = Cya[i]; - dn[i] = Cya[i]; - } - - rs = Cya[ns--]; - - for(m = 1; m < n; m++) - { - for(i = 1; i <= n-m; i++) - { - ho = Cxa[i] - x; - hp = Cxa[i+m] - x; - w = cn[i+1] - dn[i]; - - den = ho - hp; - if(den == 0.0) - { - std::cout << "Two Xs are equal" << std::endl; - // ModuleBase::WARNING_QUIT("Mathzone_Add1::Polynomial_Interpolation","Two Xs are equal"); - exit(0); // mohan update 2021-05-06 - } - den = w / den; - - dn[i] = hp * den; - cn[i] = ho * den; - } - if(2 * ns < n-m) drs = cn[ns+1]; - else drs = dn[ns--]; - - rs += drs; - } - - delete[] cn; - delete[] dn; - - return rs; - ModuleBase::timer::tick("Mathzone_Add1","Polynomial_Interpolation"); - -} - void Mathzone_Add1::SplineD2 // modified by pengfei 13-8-8 add second derivative as a condition ( @@ -1024,124 +182,7 @@ void Mathzone_Add1::Cubic_Spline_Interpolation ModuleBase::timer::tick("Mathzone_Add1","Cubic_Spline_Interpolation"); } -// Interpolation for Numerical Orbitals -double Mathzone_Add1::RadialF -( - const double* rad, - const double* rad_f, - const int& msh, - const int& l, - const double& R -) -{ - ModuleBase::timer::tick("Mathzone_Add1","RadialF"); - - int mp_min, mp_max, m; - double h1, h2, h3, f1, f2, f3, f4; - double g1, g2, x1, x2, y1, y2, f; - double c, result; - - mp_min = 0; - mp_max = msh - 1; - - //assume psir behaves like r**l - if (R < rad[0]) - { - if (l == 0) - { - f = rad_f[0]; - } - else - { - c = rad_f[0] / pow(rad[0], l); - f = pow(R, l) * c; - } - } - else if (rad[mp_max] < R) - { - f = 0.0; - } - else - { - do - { - m = (mp_min + mp_max)/2; - if (rad[m] < R) mp_min = m; - else mp_max = m; - } - while((mp_max-mp_min)!=1); - m = mp_max; - - if (m < 2) - { - m = 2; - } - else if (msh <= m) - { - m = msh - 2; - } - - /**************************************************** - Spline like interpolation - ****************************************************/ - - if (m == 1) - { - h2 = rad[m] - rad[m-1]; - h3 = rad[m+1] - rad[m]; - - f2 = rad_f[m-1]; - f3 = rad_f[m]; - f4 = rad_f[m+1]; - - h1 = -(h2+h3); - f1 = f4; - } - else if (m == (msh-1)) - { - h1 = rad[m-1] - rad[m-2]; - h2 = rad[m] - rad[m-1]; - - f1 = rad_f[m-2]; - f2 = rad_f[m-1]; - f3 = rad_f[m]; - - h3 = -(h1+h2); - f4 = f1; - } - else - { - h1 = rad[m-1] - rad[m-2]; - h2 = rad[m] - rad[m-1]; - h3 = rad[m+1] - rad[m]; - - f1 = rad_f[m-2]; - f2 = rad_f[m-1]; - f3 = rad_f[m]; - f4 = rad_f[m+1]; - } - - //Calculate the value at R - - g1 = ((f3-f2)*h1/h2 + (f2-f1)*h2/h1)/(h1+h2); - g2 = ((f4-f3)*h2/h3 + (f3-f2)*h3/h2)/(h2+h3); - - x1 = R - rad[m-1]; - x2 = R - rad[m]; - y1 = x1/h2; - y2 = x2/h2; - - f = y2*y2*(3.0*f2 + h2*g1 + (2.0*f2 + h2*g1)*y2) - + y1*y1*(3.0*f3 - h2*g2 - (2.0*f3 - h2*g2)*y1); - } - - result = f; - - ModuleBase::timer::tick("Mathzone_Add1","RadialF"); - return result; -} - -// Interpolation for Numerical Orbitals +/// Interpolation for Numerical Orbitals double Mathzone_Add1::Uni_RadialF ( const double* old_phi, @@ -1368,4 +409,4 @@ void Mathzone_Add1::Uni_Deriv_Phi ModuleBase::timer::tick("Mathzone_Add1", "Uni_Deriv_Phi"); } -} \ No newline at end of file +} diff --git a/source/module_base/mathzone_add1.h b/source/module_base/mathzone_add1.h index 68d76560d45..05565ff3c2f 100644 --- a/source/module_base/mathzone_add1.h +++ b/source/module_base/mathzone_add1.h @@ -1,8 +1,8 @@ #ifndef MATHZONE_ADD1_H #define MATHZONE_ADD1_H -#include #include +#include namespace ModuleBase { @@ -10,188 +10,74 @@ namespace ModuleBase /************************************************************************ LiaoChen add @ 2010/03/09 to add efficient functions in LCAO calculation ************************************************************************/ -//Only used in module_orbital +// Only used in module_orbital class Mathzone_Add1 { -public: - + public: Mathzone_Add1(); ~Mathzone_Add1(); - static void Sph_Bes (double x, int lmax, double *sb, double *dsb); - - //calculate jlx and its derivatives - static void Spherical_Bessel - ( - const int &msh, //number of grid points - const double *r,//radial grid - const double &q, // - const int &l, //angular momentum - double *sj, //jl(1:msh) = j_l(q*r(i)),spherical bessel function - double *sjp - ); - - /*********************************************** - function : integrate function within [a,b] - formula : \int func(r) dr - ***********************************************/ - static double uni_simpson (const double* func, const int& mshr, const double& dr); - - /******************************************************** - function : radial fourier transform - formula : F(k)_{n,l} = \int r^{n} dr jl(kr) \phi(r) - input : - int n (exponential index) - int aml (angular momentum) - int mshk (number of kpoints) - double* arr_k (an array of kpoints) - int mshr (number of grids) - double* ri (radial points) - double dr (delta r) - double* phir (function to be transformed) - output : - double* phik (with size of mshk) - ********************************************************/ - static void uni_radfft - ( - const int& n, - const int& aml, - const int& mshk, - const double* arr_k, - const int& mshr, - const double* ri, - const double& dr, - const double* phir, - double* phik - ); - - static void Sbt_new - ( - const int& polint_order, - const int& l, - const double* k, - const double& dk, - const int& mshk, - const double* r, - const double& dr, - const int& mshr, - const double* fr, - const int& rpow, - double* fk - ); + static double dualfac(const int& l); + static double factorial(const int& l); + /** + * @brief calculate second derivatives for cubic + * spline interpolation + * + * @param[in] rad x before interpolation + * @param[in] rad_f f(x) before interpolation + * @param[in] mesh number of x before interpolation + * @param[in] yp1 f'(0) boundary condition + * @param[in] ypn f'(n) boundary condition + * @param[out] y2 f''(x) + */ + static void SplineD2(const double* rad, + const double* rad_f, + const int& mesh, + const double& yp1, + const double& ypn, + double* y2); - static double dualfac (const int& l); - static double factorial (const int& l); - static double fourier_cosine_transform - ( - const double* func, - const double* r, - const int& mshr, - const double& dr, - const double& k - ); - - static double fourier_sine_transform - ( - const double* func, - const double* r, - const int& mshr, - const double& dr, - const double& k - ); - -/* - static void RecursiveRadfft - ( - const int& l, - const double** sb, - const double** dsb, - const double* arr_k, - const double* kpoint, - const int& kmesh, - const double& dk, - double& rs, - double& drs - ); -*/ + /** + * @brief cubic spline interpolation + * + * @param[in] rad x before interpolation + * @param[in] rad_f f(x) before inpterpolation + * @param[in] y2 f''(x) before interpolation + * @param[in] mesh number of x before interpolation + * @param[in] r x after interpolation + * @param[in] rsize number of x after interpolation + * @param[out] y f(x) after interpolation + * @param[out] dy f'(x) after interpolation + */ + static void Cubic_Spline_Interpolation(const double* const rad, + const double* const rad_f, + const double* const y2, + const int& mesh, + const double* const r, + const int& rsize, + double* const y, + double* const dy); - static double Polynomial_Interpolation - ( - const double* xa, - const double* ya, - const int& n, - const double& x - ); - - static void SplineD2 - ( - const double *rad, - const double *rad_f, - const int& mesh, - const double &yp1, - const double &ypn, - double* y2 - ); + /** + * @brief "spline like interpolation" of a uniform + * funcation of r + * + * @param[in] rad_f f(x) before interpolation + * @param[in] msh number of x known + * @param[in] dr uniform distance of x + * @param R f(R) is to be calculated + * @return double f(R) + */ + static double Uni_RadialF(const double* rad_f, const int& msh, const double& dr, const double& R); - static void Cubic_Spline_Interpolation - ( - const double * const rad, - const double * const rad_f, - const double * const y2, - const int& mesh, - const double * const r, - const int& rsize, - double * const y, - double * const dy - ); - - static double RadialF - ( - const double* rad, - const double* rad_f, - const int& msh, - const int& l, - const double& R - ); - - static double Uni_RadialF - ( - const double* rad_f, - const int& msh, - const double& dr, - const double& R - ); - - static void test (); - static void test2 (); - - - static void Uni_Deriv_Phi - ( - const double *radf, - const int &mesh, - const double &dr, - const int &nd, - double* phind - ); - - private: - const static int sph_lmax = 20; - static double** c_ln_c; - static double** c_ln_s; - static bool flag_jlx_expand_coef; + static void Uni_Deriv_Phi(const double* radf, const int& mesh, const double& dr, const int& nd, double* phind); - static void expand_coef_jlx (); - static double pol_seg_int - ( - const int& polint_order, - const double* coef, - const int& n, - const double* k, - const int& ik - ); + private: + const static int sph_lmax = 20; + static double** c_ln_c; + static double** c_ln_s; }; -} +} // namespace ModuleBase #endif diff --git a/source/module_base/test/CMakeLists.txt b/source/module_base/test/CMakeLists.txt index 4ddcc55e72a..46930a321c3 100644 --- a/source/module_base/test/CMakeLists.txt +++ b/source/module_base/test/CMakeLists.txt @@ -63,4 +63,9 @@ AddTest( TARGET base_mathzone LIBS ${math_libs} SOURCES mathzone_test.cpp ../mathzone.cpp ../matrix3.cpp ../matrix.cpp ../tool_quit.cpp ../global_variable.cpp ../global_file.cpp ../memory.cpp ../timer.cpp -) \ No newline at end of file +) +AddTest( + TARGET base_mathzone_add1 + LIBS ${math_libs} + SOURCES mathzone_add1_test.cpp ../mathzone_add1.cpp ../math_sphbes.cpp ../mathzone.cpp ../matrix3.cpp ../matrix.cpp ../tool_quit.cpp ../global_variable.cpp ../global_file.cpp ../memory.cpp ../timer.cpp +) diff --git a/source/module_base/test/mathzone_add1_test.cpp b/source/module_base/test/mathzone_add1_test.cpp new file mode 100644 index 00000000000..9a2e45fe45f --- /dev/null +++ b/source/module_base/test/mathzone_add1_test.cpp @@ -0,0 +1,222 @@ +#include "../mathzone_add1.h" +#include "gtest/gtest.h" +#include "gmock/gmock.h" + +/************************************************ + * unit test of class Mathzone_Add1 + ***********************************************/ + +/** + * - Tested Functions: + * - CubicSplineBoundary1 + * - call SplineD2 and Cubic_Spline_Interpolation + * - to interpolate a function using the first + * - kind of boundary condition:f'(0) = f'(n) = 0.0 + * - CubicSplineBoundary2 + * - call SplineD2 and Cubic_Spline_Interpolation + * - to interpolate a function using the second + * - kind of boundary condition:f''(0) = f''(n) = 0.0 + * - UniRadialF + * - call Uni_RadialF to interpolate the radial part + * - of an atomic orbital function, whose discrete r + * - points are uniform. + * - Factorial + * - calculate the factorial of an integer + * - DualFac + * - calculate the double factorial or + * - semifactorial of an integer + */ + +class MathzoneAdd1Test : public testing::Test +{ +protected: + const int MaxInt = 100; + int nr_in = 17; + int nr_out = 161; + double *r_in = new double[nr_in]; + double *r_out = new double[nr_out]; + double *y2 = new double[nr_in]; + double *psi_in = new double[nr_in]; + double *psi_out = new double[nr_out]; + double *dpsi = new double[nr_out]; + void TearDown() + { + delete[] r_in; + delete[] r_out; + delete[] y2; + delete[] psi_in; + delete[] psi_out; + delete[] dpsi; + } +}; + +/// first kind boundary condition: f'(0) = f'(n) = 0.0 +TEST_F(MathzoneAdd1Test, CubicSplineBoundary1) +{ + // data from abacus/tests/integrate/tools/PP_ORB/Si_gga_8au_60Ry_2s2p1d.orb + // data for d orbital of Si : L = 2, N = 0 + psi_in[0] = 0; + psi_in[1] = -2.583946346740e-01; + psi_in[2] = -4.570087269049e-01; + psi_in[3] = -4.374680500187e-01; + psi_in[4] = -3.587829079989e-01; + psi_in[5] = -2.581772323753e-01; + psi_in[6] = -1.616203660437e-01; + psi_in[7] = -9.108838081645e-02; + psi_in[8] = -5.202206559586e-02; + psi_in[9] = -3.126875315134e-02; + psi_in[10] = -1.860199873973e-02; + psi_in[11] = -8.049945178799e-03; + psi_in[12] = -1.652010824028e-03; + psi_in[13] = 1.495515249035e-03; + psi_in[14] = 3.221037475903e-03; + psi_in[15] = 3.802139894646e-03; + psi_in[16] = 0; + for (int i=0; i< nr_in; i++) + { + r_in[i] = i*0.5; + //std::cout<< r_in[i] << " " << psi_in[i] << std::endl; // for plotting + } + for (int i=0; i< nr_out; i++) + { + r_out[i] = i*0.05; + } + ModuleBase::Mathzone_Add1::SplineD2(r_in,psi_in,nr_in,0.0,0.0,y2); + //std::cout << "y2[0] "<< y2[0] << " y2[nr_in] "<< y2[nr_in-1] << std::endl; // for checking + ModuleBase::Mathzone_Add1::Cubic_Spline_Interpolation(r_in,psi_in,y2,nr_in,r_out,nr_out,psi_out,dpsi); + for (int i=0; i< nr_out; i++) + { + int j = i/10; + if(i%10==0) { + EXPECT_EQ(psi_in[j],psi_out[i]); + } + //std::cout<< r_out[i] << " " << psi_out[i] << std::endl; // for plotting + } + EXPECT_NEAR(dpsi[0],0.0,1e-15); + EXPECT_NEAR(dpsi[nr_out-1],0.0,1e-15); + //std::cout< Date: Thu, 17 Feb 2022 16:47:48 +0800 Subject: [PATCH 33/52] test:modify mathzone_add1_test.cpp --- .../module_base/test/mathzone_add1_test.cpp | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/source/module_base/test/mathzone_add1_test.cpp b/source/module_base/test/mathzone_add1_test.cpp index 9a2e45fe45f..c10c86c0082 100644 --- a/source/module_base/test/mathzone_add1_test.cpp +++ b/source/module_base/test/mathzone_add1_test.cpp @@ -33,12 +33,21 @@ class MathzoneAdd1Test : public testing::Test const int MaxInt = 100; int nr_in = 17; int nr_out = 161; - double *r_in = new double[nr_in]; - double *r_out = new double[nr_out]; - double *y2 = new double[nr_in]; - double *psi_in = new double[nr_in]; - double *psi_out = new double[nr_out]; - double *dpsi = new double[nr_out]; + double *r_in; + double *r_out; + double *y2; + double *psi_in; + double *psi_out; + double *dpsi; + void SetUp() + { + r_in = new double[nr_in]; + r_out = new double[nr_out]; + y2 = new double[nr_in]; + psi_in = new double[nr_in]; + psi_out = new double[nr_out]; + dpsi = new double[nr_out]; + } void TearDown() { delete[] r_in; From a3cd7d80407420b7ab50ba1391219ed5363aaeed Mon Sep 17 00:00:00 2001 From: xingliang Date: Thu, 17 Feb 2022 17:05:39 +0800 Subject: [PATCH 34/52] use M_PI in math.h --- source/module_base/test/math_ylmreal_test.cpp | 52 +++++++++---------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/source/module_base/test/math_ylmreal_test.cpp b/source/module_base/test/math_ylmreal_test.cpp index 455dfb7e95f..b8b85e18d54 100644 --- a/source/module_base/test/math_ylmreal_test.cpp +++ b/source/module_base/test/math_ylmreal_test.cpp @@ -2,8 +2,8 @@ #include"../vector3.h" #include"../matrix.h" #include"gtest/gtest.h" +#include -#define PI 3.141592653589793238462643383279502884197169399 #define doublethreshold 1e-12 /************************************************ @@ -49,31 +49,31 @@ class YlmRealTest : public testing::Test //https://formulasearchengine.com/wiki/Table_of_spherical_harmonics //multipy the Condon–Shortley phase (-1)^m inline double norm(const double &x, const double &y, const double &z) {return sqrt(x*x + y*y + z*z);} - double y00(const double &x, const double &y, const double &z) {return 1.0/2.0/sqrt(PI);} - double y10(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return sqrt(3.0/(4.0*PI)) * z / r;} - double y11(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*sqrt(3.0/(4.*PI)) * x / r;} - double y1m1(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*sqrt(3./(4.*PI)) * y / r;} // y1m1 means Y1,-1 - double y20(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return 1./4. * sqrt(5./PI) * (-1.*x*x - y*y + 2.*z*z) / (r*r);} - double y21(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*1./2. * sqrt(15./PI) * (z*x) / (r*r);} - double y2m1(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*1./2. * sqrt(15./PI) * (z*y) / (r*r);} - double y22(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return 1./4. * sqrt(15./PI) * (x*x - y*y) / (r*r);} - double y2m2(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return 1./2. * sqrt(15./PI) * (x*y) / (r*r);} - double y30(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return 1./4. * sqrt(7./PI) * z*(2.*z*z-3.*x*x-3.*y*y) / (r*r*r);} - double y31(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*1./4. * sqrt(21./2./PI) * x*(4.*z*z-x*x-y*y) / (r*r*r);} - double y3m1(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*1./4. * sqrt(21./2./PI) * y*(4.*z*z-x*x-y*y) / (r*r*r);} - double y32(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return 1./4. * sqrt(105./PI) * (x*x - y*y)*z / (r*r*r);} - double y3m2(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return 1./2. * sqrt(105./PI) * x*y*z / (r*r*r);} - double y33(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*1./4. * sqrt(35./2./PI) * x*(x*x - 3.*y*y) / (r*r*r);} - double y3m3(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*1./4. * sqrt(35./2./PI) * y*(3.*x*x - y*y) / (r*r*r);} - double y40(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return 3./16.*sqrt(1./PI) * (35.*z*z*z*z - 30.*z*z*r*r + 3*r*r*r*r) / (r*r*r*r);} - double y41(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*3./4.*sqrt(5./2./PI) * x*z*(7.*z*z - 3*r*r) / (r*r*r*r);} - double y4m1(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*3./4.*sqrt(5./2./PI) * y*z*(7.*z*z - 3.*r*r) / (r*r*r*r);} - double y42(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return 3./8.*sqrt(5./PI) * (x*x-y*y)*(7.*z*z-r*r) / (r*r*r*r);} - double y4m2(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return 3./4.*sqrt(5./PI) * x*y*(7.*z*z - r*r) / (r*r*r*r);} - double y43(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*3./4.*sqrt(35./2./PI) * x*z*(x*x - 3.*y*y) / (r*r*r*r);} - double y4m3(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*3./4.*sqrt(35./2./PI) * y*z*(3.*x*x - y*y) / (r*r*r*r);} - double y44(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return 3./16.*sqrt(35./PI) * (x*x*(x*x - 3.*y*y) - y*y*(3.*x*x-y*y)) / (r*r*r*r);} - double y4m4(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return 3./4.*sqrt(35./PI) * x*y*(x*x - y*y) / (r*r*r*r);} + double y00(const double &x, const double &y, const double &z) {return 1.0/2.0/sqrt(M_PI);} + double y10(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return sqrt(3.0/(4.0*M_PI)) * z / r;} + double y11(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*sqrt(3.0/(4.*M_PI)) * x / r;} + double y1m1(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*sqrt(3./(4.*M_PI)) * y / r;} // y1m1 means Y1,-1 + double y20(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return 1./4. * sqrt(5./M_PI) * (-1.*x*x - y*y + 2.*z*z) / (r*r);} + double y21(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*1./2. * sqrt(15./M_PI) * (z*x) / (r*r);} + double y2m1(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*1./2. * sqrt(15./M_PI) * (z*y) / (r*r);} + double y22(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return 1./4. * sqrt(15./M_PI) * (x*x - y*y) / (r*r);} + double y2m2(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return 1./2. * sqrt(15./M_PI) * (x*y) / (r*r);} + double y30(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return 1./4. * sqrt(7./M_PI) * z*(2.*z*z-3.*x*x-3.*y*y) / (r*r*r);} + double y31(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*1./4. * sqrt(21./2./M_PI) * x*(4.*z*z-x*x-y*y) / (r*r*r);} + double y3m1(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*1./4. * sqrt(21./2./M_PI) * y*(4.*z*z-x*x-y*y) / (r*r*r);} + double y32(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return 1./4. * sqrt(105./M_PI) * (x*x - y*y)*z / (r*r*r);} + double y3m2(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return 1./2. * sqrt(105./M_PI) * x*y*z / (r*r*r);} + double y33(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*1./4. * sqrt(35./2./M_PI) * x*(x*x - 3.*y*y) / (r*r*r);} + double y3m3(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*1./4. * sqrt(35./2./M_PI) * y*(3.*x*x - y*y) / (r*r*r);} + double y40(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return 3./16.*sqrt(1./M_PI) * (35.*z*z*z*z - 30.*z*z*r*r + 3*r*r*r*r) / (r*r*r*r);} + double y41(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*3./4.*sqrt(5./2./M_PI) * x*z*(7.*z*z - 3*r*r) / (r*r*r*r);} + double y4m1(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*3./4.*sqrt(5./2./M_PI) * y*z*(7.*z*z - 3.*r*r) / (r*r*r*r);} + double y42(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return 3./8.*sqrt(5./M_PI) * (x*x-y*y)*(7.*z*z-r*r) / (r*r*r*r);} + double y4m2(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return 3./4.*sqrt(5./M_PI) * x*y*(7.*z*z - r*r) / (r*r*r*r);} + double y43(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*3./4.*sqrt(35./2./M_PI) * x*z*(x*x - 3.*y*y) / (r*r*r*r);} + double y4m3(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*3./4.*sqrt(35./2./M_PI) * y*z*(3.*x*x - y*y) / (r*r*r*r);} + double y44(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return 3./16.*sqrt(35./M_PI) * (x*x*(x*x - 3.*y*y) - y*y*(3.*x*x-y*y)) / (r*r*r*r);} + double y4m4(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return 3./4.*sqrt(35./M_PI) * x*y*(x*x - y*y) / (r*r*r*r);} void SetUp() { From 5d7588c1313a39c763e9f142d04f31487abae03f Mon Sep 17 00:00:00 2001 From: Chun Cai Date: Mon, 21 Feb 2022 17:29:07 +0800 Subject: [PATCH 35/52] Fix #736: memory free by wrong index --- source/src_pw/hamilt_pw.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/src_pw/hamilt_pw.cu b/source/src_pw/hamilt_pw.cu index 7ef18d93d43..849cc941d8f 100644 --- a/source/src_pw/hamilt_pw.cu +++ b/source/src_pw/hamilt_pw.cu @@ -1553,7 +1553,7 @@ void Hamilt_PW::add_nonlocal_pp_cuda( hpsi_in, GlobalC::wf.npwx)); } - if(m == 1) + if(m != 1) { CHECK_CUDA(cudaFree(ps)); } From 9a14a1acbc2de02dd849812830d0ecec36c8e363 Mon Sep 17 00:00:00 2001 From: xingliang Date: Tue, 22 Feb 2022 16:41:03 +0800 Subject: [PATCH 36/52] test: add the unit test and comments for ylm.h range: source/module_base --- source/module_base/test/CMakeLists.txt | 4 +- source/module_base/test/math_ylmreal_test.cpp | 203 ++++++++++++++++-- source/module_base/ylm.h | 85 ++++++-- 3 files changed, 259 insertions(+), 33 deletions(-) diff --git a/source/module_base/test/CMakeLists.txt b/source/module_base/test/CMakeLists.txt index ddfa687934d..e8cd96a75da 100644 --- a/source/module_base/test/CMakeLists.txt +++ b/source/module_base/test/CMakeLists.txt @@ -73,7 +73,7 @@ AddTest( SOURCES math_polyint_test.cpp ../math_polyint.cpp ../realarray.cpp ../timer.cpp ) AddTest( - TARGET base_math_ylmreal + TARGET base_ylmreal LIBS ${math_libs} - SOURCES math_ylmreal_test.cpp ../math_ylmreal.cpp ../realarray.cpp ../timer.cpp ../matrix.cpp + SOURCES math_ylmreal_test.cpp ../math_ylmreal.cpp ../ylm.cpp ../realarray.cpp ../timer.cpp ../matrix.cpp ../vector3.h ) \ No newline at end of file diff --git a/source/module_base/test/math_ylmreal_test.cpp b/source/module_base/test/math_ylmreal_test.cpp index b8b85e18d54..c38c56162e6 100644 --- a/source/module_base/test/math_ylmreal_test.cpp +++ b/source/module_base/test/math_ylmreal_test.cpp @@ -1,4 +1,5 @@ #include"../math_ylmreal.h" +#include"../ylm.h" #include"../vector3.h" #include"../matrix.h" #include"gtest/gtest.h" @@ -21,6 +22,13 @@ * - Ylm_Real * - Ylm_Real2 * - rlylm + * + * - Tested functions of class Ylm + * - get_ylm_real + * - sph_harm + * - rl_sph_harm + * - grad_rl_sph_harm + * - */ @@ -37,17 +45,19 @@ class YlmRealTest : public testing::Test { protected: - int lmax = 7; - int ng = 4; //test the 4 selected points on the sphere - int nylm ; // total Ylm number; - ModuleBase::matrix ylm; - ModuleBase::Vector3 *g; - double *ref; - double *rly; + int lmax = 7; //maximum angular quantum number + int ng = 4; //test the 4 selected points on the sphere + int nylm = 64; //total Ylm number; + + ModuleBase::matrix ylm; //Ylm + ModuleBase::Vector3 *g; //vectors of the 4 points + double *ref; //reference of Ylm + double *rly; //Ylm + double (*rlgy)[3]; //the gradient of Ylm + std::vector rlyvector; //Ylm + std::vector> rlgyvector; //the gradient of Ylm //Ylm function - //https://formulasearchengine.com/wiki/Table_of_spherical_harmonics - //multipy the Condon–Shortley phase (-1)^m inline double norm(const double &x, const double &y, const double &z) {return sqrt(x*x + y*y + z*z);} double y00(const double &x, const double &y, const double &z) {return 1.0/2.0/sqrt(M_PI);} double y10(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return sqrt(3.0/(4.0*M_PI)) * z / r;} @@ -75,9 +85,108 @@ class YlmRealTest : public testing::Test double y44(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return 3./16.*sqrt(35./M_PI) * (x*x*(x*x - 3.*y*y) - y*y*(3.*x*x-y*y)) / (r*r*r*r);} double y4m4(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return 3./4.*sqrt(35./M_PI) * x*y*(x*x - y*y) / (r*r*r*r);} + //the reference values are calculated by ModuleBase::Ylm::grad_rl_sph_harm + //1st dimension: example, 2nd dimension: Ylm, 3rd dimension: dx/dy/dz + double rlgyref[4][64][3] = { + { { 0.00000e+00, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 4.88603e-01}, {-4.88603e-01, 0.00000e+00, 0.00000e+00}, + { 0.00000e+00, -4.88603e-01, 0.00000e+00}, {-6.30783e-01, 0.00000e+00, 0.00000e+00}, {-0.00000e+00, 0.00000e+00, -1.09255e+00}, + { 0.00000e+00, -0.00000e+00, 0.00000e+00}, { 1.09255e+00, 0.00000e+00, 0.00000e+00}, {-0.00000e+00, 1.09255e+00, -0.00000e+00}, + {-0.00000e+00, 0.00000e+00, -1.11953e+00}, { 1.37114e+00, 0.00000e+00, -0.00000e+00}, { 0.00000e+00, 4.57046e-01, 0.00000e+00}, + { 0.00000e+00, 0.00000e+00, 1.44531e+00}, {-0.00000e+00, 0.00000e+00, 0.00000e+00}, {-1.77013e+00, 0.00000e+00, -0.00000e+00}, + { 0.00000e+00, -1.77013e+00, 0.00000e+00}, { 1.26943e+00, 0.00000e+00, -0.00000e+00}, { 0.00000e+00, 0.00000e+00, 2.00714e+00}, + { 0.00000e+00, 0.00000e+00, 0.00000e+00}, {-1.89235e+00, 0.00000e+00, 0.00000e+00}, {-0.00000e+00, -9.46175e-01, 0.00000e+00}, + {-0.00000e+00, 0.00000e+00, -1.77013e+00}, { 0.00000e+00, -0.00000e+00, 0.00000e+00}, { 2.50334e+00, 0.00000e+00, 0.00000e+00}, + {-0.00000e+00, 2.50334e+00, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 1.75425e+00}, {-2.26473e+00, 0.00000e+00, 0.00000e+00}, + { 0.00000e+00, -4.52947e-01, 0.00000e+00}, {-0.00000e+00, 0.00000e+00, -2.39677e+00}, {-0.00000e+00, -0.00000e+00, 0.00000e+00}, + { 2.44619e+00, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, 1.46771e+00, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 2.07566e+00}, + {-0.00000e+00, 0.00000e+00, 0.00000e+00}, {-3.28191e+00, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, -3.28191e+00, 0.00000e+00}, + {-1.90708e+00, 0.00000e+00, 0.00000e+00}, {-0.00000e+00, 0.00000e+00, -2.91311e+00}, { 0.00000e+00, -0.00000e+00, 0.00000e+00}, + { 2.76362e+00, 0.00000e+00, -0.00000e+00}, {-0.00000e+00, 9.21205e-01, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 2.76362e+00}, + { 0.00000e+00, 0.00000e+00, 0.00000e+00}, {-3.02739e+00, 0.00000e+00, 0.00000e+00}, {-0.00000e+00, -2.01826e+00, 0.00000e+00}, + {-0.00000e+00, 0.00000e+00, -2.36662e+00}, { 0.00000e+00, -0.00000e+00, 0.00000e+00}, { 4.09910e+00, 0.00000e+00, 0.00000e+00}, + {-0.00000e+00, 4.09910e+00, 0.00000e+00}, {-0.00000e+00, 0.00000e+00, -2.38995e+00}, { 3.16161e+00, 0.00000e+00, -0.00000e+00}, + { 0.00000e+00, 4.51658e-01, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 3.31900e+00}, {-0.00000e+00, 0.00000e+00, 0.00000e+00}, + {-3.28564e+00, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, -1.40813e+00, 0.00000e+00}, {-0.00000e+00, 0.00000e+00, -3.11349e+00}, + {-0.00000e+00, -0.00000e+00, 0.00000e+00}, { 3.63241e+00, 0.00000e+00, -0.00000e+00}, { 0.00000e+00, 2.59458e+00, 0.00000e+00}, + { 0.00000e+00, 0.00000e+00, 2.64596e+00}, { 0.00000e+00, 0.00000e+00, 0.00000e+00}, {-4.95014e+00, 0.00000e+00, -0.00000e+00}, + { 0.00000e+00, -4.95014e+00, 0.00000e+00} + }, + { + { 0.00000e+00, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 4.88603e-01}, {-4.88603e-01, 0.00000e+00, 0.00000e+00}, + { 0.00000e+00, -4.88603e-01, 0.00000e+00}, { 0.00000e+00, -6.30783e-01, 0.00000e+00}, {-0.00000e+00, 0.00000e+00, 0.00000e+00}, + { 0.00000e+00, -0.00000e+00, -1.09255e+00}, { 0.00000e+00, -1.09255e+00, 0.00000e+00}, { 1.09255e+00, 0.00000e+00, -0.00000e+00}, + { 0.00000e+00, -0.00000e+00, -1.11953e+00}, { 4.57046e-01, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, 1.37114e+00, -0.00000e+00}, + { 0.00000e+00, -0.00000e+00, -1.44531e+00}, { 0.00000e+00, 0.00000e+00, 0.00000e+00}, { 1.77013e+00, 0.00000e+00, 0.00000e+00}, + { 0.00000e+00, 1.77013e+00, 0.00000e+00}, { 0.00000e+00, 1.26943e+00, -0.00000e+00}, { 0.00000e+00, 0.00000e+00, 0.00000e+00}, + { 0.00000e+00, 0.00000e+00, 2.00714e+00}, { 0.00000e+00, 1.89235e+00, -0.00000e+00}, {-9.46175e-01, 0.00000e+00, 0.00000e+00}, + { 0.00000e+00, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 1.77013e+00}, { 0.00000e+00, 2.50334e+00, -0.00000e+00}, + {-2.50334e+00, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 1.75425e+00}, {-4.52947e-01, 0.00000e+00, 0.00000e+00}, + { 0.00000e+00, -2.26473e+00, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 2.39677e+00}, {-0.00000e+00, 0.00000e+00, 0.00000e+00}, + {-1.46771e+00, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, -2.44619e+00, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 2.07566e+00}, + {-0.00000e+00, 0.00000e+00, 0.00000e+00}, {-3.28191e+00, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, -3.28191e+00, 0.00000e+00}, + { 0.00000e+00, -1.90708e+00, 0.00000e+00}, {-0.00000e+00, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, -0.00000e+00, -2.91311e+00}, + { 0.00000e+00, -2.76362e+00, 0.00000e+00}, { 9.21205e-01, 0.00000e+00, 0.00000e+00}, {-0.00000e+00, 0.00000e+00, 0.00000e+00}, + { 0.00000e+00, -0.00000e+00, -2.76362e+00}, { 0.00000e+00, -3.02739e+00, 0.00000e+00}, { 2.01826e+00, 0.00000e+00, 0.00000e+00}, + {-0.00000e+00, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, -0.00000e+00, -2.36662e+00}, { 0.00000e+00, -4.09910e+00, 0.00000e+00}, + { 4.09910e+00, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, -0.00000e+00, -2.38995e+00}, { 4.51658e-01, 0.00000e+00, 0.00000e+00}, + { 0.00000e+00, 3.16161e+00, -0.00000e+00}, { 0.00000e+00, -0.00000e+00, -3.31900e+00}, { 0.00000e+00, 0.00000e+00, 0.00000e+00}, + { 1.40813e+00, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, 3.28564e+00, -0.00000e+00}, { 0.00000e+00, -0.00000e+00, -3.11349e+00}, + { 0.00000e+00, 0.00000e+00, 0.00000e+00}, { 2.59458e+00, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, 3.63241e+00, -0.00000e+00}, + { 0.00000e+00, 0.00000e+00, -2.64596e+00}, { 0.00000e+00, 0.00000e+00, 0.00000e+00}, { 4.95014e+00, 0.00000e+00, 0.00000e+00}, + { 0.00000e+00, 4.95014e+00, -0.00000e+00} + }, + { + { 0.00000e+00, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 4.88603e-01}, {-4.88603e-01, 0.00000e+00, 0.00000e+00}, + { 0.00000e+00, -4.88603e-01, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 1.26157e+00}, {-1.09255e+00, 0.00000e+00, 0.00000e+00}, + { 0.00000e+00, -1.09255e+00, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 2.22045e-16}, {-0.00000e+00, 0.00000e+00, -0.00000e+00}, + { 0.00000e+00, 0.00000e+00, 2.23906e+00}, {-1.82818e+00, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, -1.82818e+00, 0.00000e+00}, + { 0.00000e+00, 0.00000e+00, 8.81212e-16}, {-0.00000e+00, 0.00000e+00, 0.00000e+00}, {-1.84324e-16, 0.00000e+00, 0.00000e+00}, + { 0.00000e+00, 5.55112e-17, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 3.38514e+00}, {-2.67619e+00, 0.00000e+00, 0.00000e+00}, + { 0.00000e+00, -2.67619e+00, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 2.30756e-15}, {-0.00000e+00, 0.00000e+00, 0.00000e+00}, + {-5.52973e-16, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, 1.66533e-16, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 0.00000e+00}, + {-0.00000e+00, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 4.67801e+00}, {-3.62357e+00, 0.00000e+00, 0.00000e+00}, + { 0.00000e+00, -3.62357e+00, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 4.87108e-15}, {-0.00000e+00, 0.00000e+00, 0.00000e+00}, + {-1.22267e-15, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, 3.68219e-16, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 0.00000e+00}, + {-0.00000e+00, 0.00000e+00, 0.00000e+00}, { 4.93038e-32, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, -6.16298e-33, 0.00000e+00}, + { 0.00000e+00, 0.00000e+00, 6.10264e+00}, {-4.66097e+00, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, -4.66097e+00, 0.00000e+00}, + { 0.00000e+00, 0.00000e+00, 8.98664e-15}, {-0.00000e+00, 0.00000e+00, 0.00000e+00}, {-2.30221e-15, 0.00000e+00, 0.00000e+00}, + { 0.00000e+00, 6.93334e-16, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 0.00000e+00}, {-0.00000e+00, 0.00000e+00, 0.00000e+00}, + { 1.77767e-31, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, -2.22209e-32, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 0.00000e+00}, + {-0.00000e+00, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 7.64784e+00}, {-5.78122e+00, 0.00000e+00, 0.00000e+00}, + { 0.00000e+00, -5.78122e+00, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 1.51096e-14}, {-0.00000e+00, 0.00000e+00, 0.00000e+00}, + {-3.91011e-15, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, 1.17757e-15, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 0.00000e+00}, + {-0.00000e+00, 0.00000e+00, 0.00000e+00}, { 4.67737e-31, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, -5.84671e-32, 0.00000e+00}, + { 0.00000e+00, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 0.00000e+00}, { 1.13319e-47, 0.00000e+00, 0.00000e+00}, + { 0.00000e+00, -1.41649e-48, 0.00000e+00} + }, + { + { 0.00000e+00, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 4.88603e-01}, {-4.88603e-01, 0.00000e+00, 0.00000e+00}, + { 0.00000e+00, -4.88603e-01, 0.00000e+00}, { 3.64183e-01, 3.64183e-01, -7.28366e-01}, { 6.30783e-01, -0.00000e+00, 6.30783e-01}, + {-0.00000e+00, 6.30783e-01, 6.30783e-01}, {-6.30783e-01, 6.30783e-01, -1.66533e-16}, {-6.30783e-01, -6.30783e-01, 0.00000e+00}, + {-7.46353e-01, -7.46353e-01, 0.00000e+00}, { 0.00000e+00, 3.04697e-01, -1.21879e+00}, { 3.04697e-01, 0.00000e+00, -1.21879e+00}, + { 9.63537e-01, -9.63537e-01, 4.01253e-16}, { 9.63537e-01, 9.63537e-01, 9.63537e-01}, {-4.44089e-16, 1.18009e+00, -2.22045e-16}, + {-1.18009e+00, -1.11022e-16, 0.00000e+00}, { 4.88603e-01, 4.88603e-01, 1.30294e+00}, {-1.03006e+00, -7.72548e-01, 7.72548e-01}, + {-7.72548e-01, -1.03006e+00, 7.72548e-01}, {-7.28366e-01, 7.28366e-01, -5.25363e-16}, {-3.64183e-01, -3.64183e-01, -2.18510e+00}, + { 7.69185e-16, -2.04397e+00, -6.81324e-01}, { 2.04397e+00, 1.92296e-16, 6.81324e-01}, { 9.63537e-01, 9.63537e-01, -1.44756e-16}, + {-9.63537e-01, 9.63537e-01, -5.55112e-17}, { 5.19779e-01, 5.19779e-01, -1.81923e+00}, { 1.40917e+00, 8.05238e-01, 8.05238e-01}, + { 8.05238e-01, 1.40917e+00, 8.05238e-01}, { 0.00000e+00, -4.44089e-16, 3.24739e-16}, {-1.06523e+00, -1.06523e+00, 2.13046e+00}, + {-2.17439e-01, 1.73951e+00, 1.73951e+00}, {-1.73951e+00, 2.17439e-01, -1.73951e+00}, {-1.84503e+00, -1.84503e+00, -9.22517e-01}, + { 1.84503e+00, -1.84503e+00, 6.58625e-16}, { 1.45863e+00, 1.11022e-15, 0.00000e+00}, {-8.88178e-16, 1.45863e+00, 0.00000e+00}, + {-1.46807e+00, -1.46807e+00, 5.87227e-01}, {-4.48502e-01, -3.36617e-16, -2.24251e+00}, {-3.36617e-16, -4.48502e-01, -2.24251e+00}, + { 7.09144e-01, -7.09144e-01, 1.87222e-16}, { 2.12743e+00, 2.12743e+00, -9.38779e-16}, { 7.09144e-01, -5.11006e-16, -2.12743e+00}, + { 1.02201e-15, -7.09144e-01, 2.12743e+00}, { 1.81260e+00, 1.81260e+00, 2.58943e+00}, {-2.07154e+00, 2.07154e+00, -1.66969e-15}, + {-3.03637e+00, -2.31111e-15, -6.07275e-01}, { 1.84889e-15, -3.03637e+00, -6.07275e-01}, { 1.05183e+00, -1.05183e+00, 5.77778e-17}, + { 1.05183e+00, 1.05183e+00, 4.03986e-17}, { 1.27464e+00, 1.27464e+00, 1.69952e+00}, {-1.28472e+00, -1.20442e+00, 1.92707e+00}, + {-1.20442e+00, -1.28472e+00, 1.92707e+00}, {-8.52285e-01, 8.52285e-01, -6.74704e-16}, {-1.50789e+00, -1.50789e+00, -2.95022e+00}, + {-1.11260e+00, -2.08612e+00, 9.27164e-01}, { 2.08612e+00, 1.11260e+00, -9.27164e-01}, {-3.07506e-01, -3.07506e-01, -3.69007e+00}, + { 1.23002e+00, -1.23002e+00, 2.28018e-15}, { 3.69007e+00, -1.53753e-01, 1.84503e+00}, {-1.53753e-01, 3.69007e+00, 1.84503e+00}, + {-2.35197e+00, 2.35197e+00, -8.00513e-16}, {-2.35197e+00, -2.35197e+00, -7.83988e-01}, { 1.37903e-15, -1.46671e+00, 9.77875e-17}, + { 1.46671e+00, 1.14919e-15, 1.34475e-16} + } + }; + void SetUp() { - nylm = (lmax + 1) * (lmax + 1); ylm.create(nylm,ng); g = new ModuleBase::Vector3[ng]; g[0].set(1.0,0.0,0.0); @@ -86,7 +195,10 @@ class YlmRealTest : public testing::Test g[3].set(-1.0,-1.0,-1.0); rly = new double[nylm]; - ref = new double[nylm*ng]{ + rlyvector.resize(nylm); + rlgy = new double[nylm][3]; + rlgyvector.resize(nylm,std::vector(3)); + ref = new double[64*4]{ y00(g[0].x, g[0].y, g[0].z), y00(g[1].x, g[1].y, g[1].z), y00(g[2].x, g[2].y, g[2].z), y00(g[3].x, g[3].y, g[3].z), y10(g[0].x, g[0].y, g[0].z), y10(g[1].x, g[1].y, g[1].z), y10(g[2].x, g[2].y, g[2].z), y10(g[3].x, g[3].y, g[3].z), y11(g[0].x, g[0].y, g[0].z), y11(g[1].x, g[1].y, g[1].z), y11(g[2].x, g[2].y, g[2].z), y11(g[3].x, g[3].y, g[3].z), @@ -152,8 +264,6 @@ class YlmRealTest : public testing::Test -0.707162732524596, 0.000000000000000, -0.000000000000000, 0.120972027847095, -0.000000000000000, 0.707162732524596, -0.000000000000000, -0.120972027847095 } ; - - } void TearDown() @@ -161,6 +271,7 @@ class YlmRealTest : public testing::Test delete [] g; delete [] ref; delete [] rly; + delete [] rlgy; } }; @@ -171,7 +282,7 @@ TEST_F(YlmRealTest,YlmReal) { for(int j=0;j &vec, double ylmr[]); - // (2) for check + /** + * @brief Get the ylm real object and the gradient + * + * @param Lmax [in] maximum angular quantum number + l + * @param vec [in] the vector to be calculated + * @param ylmr [out] calculated Ylm, Y00, Y10, Y11, Y1-1, Y20, Y21, Y2-1, Y22, Y2-2... + * @param dylmdr [out] gradient of Ylm, [dY00/dx, dY00/dy, dY00/dz], [dY10/dx, dY10/dy, dY10/dz], [dY11/dx, dY11/dy, dY11/dz],... + */ static void get_ylm_real( const int &Lmax , const ModuleBase::Vector3 &vec, double ylmr[], double dylmdr[][3]); - // (3) not used anymore. + /** + * @brief Get the ylm real (solid) object (not used anymore) + * + * @param Lmax [in] maximum angular quantum number + l + * @param x [in] x + * @param y [in] y + * @param z [in] z + * @param rly [in] calculated Ylm, Y00, Y10, Y11, Y1-1, Y20, Y21, Y2-1, Y22, Y2-2... + */ static void rlylm( - const int& Lmax, // max momentum of l +1 + const int& Lmax, const double& x, const double& y, const double& z, double rly[]); - // (4) not used anymore. + /** + * @brief Get the ylm real (solid) object and the gradient (not used anymore) + * + * @param Lmax [in] maximum angular quantum number + 1 + * @param x [in] x + * @param y [in] y + * @param z [in] z + * @param rly [in] calculated Ylm, Y00, Y10, Y11, Y1-1, Y20, Y21, Y2-1, Y22, Y2-2... + * @param grly [out] gradient of Ylm, [dY00/dx, dY00/dy, dY00/dz], [dY10/dx, dY10/dy, dY10/dz], [dY11/dx, dY11/dy, dY11/dz],... + */ static void rlylm( - const int& Lmax, // max momentum of l +1 + const int& Lmax, const double& x, const double& y, const double& z, double rly[], double grly[][3]); - - // (5) used in grid integration. + + /** + * @brief Get the ylm real object (used in grid integration) + * + * @param Lmax [in] maximum angular quantum number + * @param xdr [in] x/r + * @param ydr [in] y/r + * @param zdr [in] z/r + * @param rly [in] calculated Ylm, Y00, Y10, Y11, Y1-1, Y20, Y21, Y2-1, Y22, Y2-2... + */ static void sph_harm( const int& Lmax, const double& xdr, @@ -53,8 +92,17 @@ class Ylm const double& zdr, std::vector &rly); - // (6) used in getting overlap. - // Peize Lin change rly 2016-08-26 + /** + * @brief Get the ylm real object (used in getting overlap) + * + * @param Lmax [in] maximum angular quantum number + * @param x [in] x/r + * @param y [in] y/r + * @param z [in] z/r + * @param rly [in] calculated Ylm, Y00, Y10, Y11, Y1-1, Y20, Y21, Y2-1, Y22, Y2-2... + * @author Peize Lin + * @date 2016-08-26 + */ static void rl_sph_harm( const int& Lmax, const double& x, @@ -62,8 +110,16 @@ class Ylm const double& z, std::vector& rly); - // (6) used in getting derivative of overlap. - // Peize Lin change rly, grly 2016-08-26 + /** + * @brief Get the ylm real object and the gradient (used in getting derivative of overlap) + * + * @param Lmax [in] maximum angular quantum number + * @param x [in] x/r + * @param y [in] y/r + * @param z [in] z/r + * @param rly [in] calculated Ylm, Y00, Y10, Y11, Y1-1, Y20, Y21, Y2-1, Y22, Y2-2... + * @param grly [out] gradient of Ylm, [dY00/dx, dY00/dy, dY00/dz], [dY10/dx, dY10/dy, dY10/dz], [dY11/dx, dY11/dy, dY11/dz],... + */ static void grad_rl_sph_harm( const int& Lmax, const double& x, @@ -71,14 +127,15 @@ class Ylm const double& z, std::vector& rly, std::vector>& grly); - + + //calculate the coefficient of Ylm, ylmcoef. static void set_coefficients (); static std::vector ylmcoef; static void test(); static void test1(); static void test2(); - + //set the first n elements of u to be 0.0 static void ZEROS(double u[], const int& n); private: From 7e17d276caafbeb0df8b8c579971635bd952c9a9 Mon Sep 17 00:00:00 2001 From: qianrui Date: Thu, 24 Feb 2022 15:10:07 +0800 Subject: [PATCH 37/52] < refactor > move some functions into en_solver module Finally, we want to move all calculations about FP into en_solver/FP We only finish part of it. --- source/Makefile | 2 + source/driver.cpp | 12 +- .../module_ensolver/FP/KSDFT/PW/ks_scf_pw.cpp | 188 ++++++++++++++++++ .../module_ensolver/FP/KSDFT/PW/ks_scf_pw.h | 60 ++++++ source/module_ensolver/FP/KSDFT/ks_scf.cpp | 0 source/module_ensolver/FP/KSDFT/ks_scf.h | 14 ++ source/module_ensolver/FP/OFDFT/ofdft.h | 8 + source/module_ensolver/FP/ab_initio.cpp | 0 source/module_ensolver/FP/ab_initio.h | 14 ++ source/module_ensolver/Makefile.ensolver | 13 ++ source/module_ensolver/en_solver.cpp | 38 ++++ source/module_ensolver/en_solver.h | 37 ++++ source/module_md/FIRE.cpp | 4 +- source/module_md/FIRE.h | 2 +- source/module_md/Langevin.cpp | 4 +- source/module_md/Langevin.h | 2 +- source/module_md/MD_func.cpp | 9 +- source/module_md/MD_func.h | 3 +- source/module_md/MSST.cpp | 5 +- source/module_md/MSST.h | 2 +- source/module_md/NVE.cpp | 4 +- source/module_md/NVE.h | 2 +- source/module_md/NVT_ADS.cpp | 4 +- source/module_md/NVT_ADS.h | 2 +- source/module_md/NVT_NHC.cpp | 4 +- source/module_md/NVT_NHC.h | 2 +- source/module_md/run_md_classic.cpp | 6 +- source/module_md/verlet.cpp | 5 +- source/module_md/verlet.h | 3 +- source/run_lcao.cpp | 4 +- source/run_lcao.h | 3 +- source/run_pw.cpp | 77 +------ source/run_pw.h | 3 +- source/src_ions/Cell_PW.cpp | 73 +------ source/src_ions/Cell_PW.h | 3 +- source/src_ions/ions.cpp | 42 ++-- source/src_ions/ions.h | 9 +- source/src_lcao/run_md_lcao.cpp | 11 +- source/src_lcao/run_md_lcao.h | 8 +- source/src_pw/run_md_pw.cpp | 96 ++------- source/src_pw/run_md_pw.h | 9 +- 41 files changed, 486 insertions(+), 301 deletions(-) create mode 100644 source/module_ensolver/FP/KSDFT/PW/ks_scf_pw.cpp create mode 100644 source/module_ensolver/FP/KSDFT/PW/ks_scf_pw.h create mode 100644 source/module_ensolver/FP/KSDFT/ks_scf.cpp create mode 100644 source/module_ensolver/FP/KSDFT/ks_scf.h create mode 100644 source/module_ensolver/FP/OFDFT/ofdft.h create mode 100644 source/module_ensolver/FP/ab_initio.cpp create mode 100644 source/module_ensolver/FP/ab_initio.h create mode 100644 source/module_ensolver/Makefile.ensolver create mode 100644 source/module_ensolver/en_solver.cpp create mode 100644 source/module_ensolver/en_solver.h diff --git a/source/Makefile b/source/Makefile index d7dcf8e6bee..d07fb6c6491 100644 --- a/source/Makefile +++ b/source/Makefile @@ -25,6 +25,8 @@ VPATH=./src_global\ :./src_ri\ :./\ +include module_ensolver/Makefile.ensolver + #========================== # Define HONG #========================== diff --git a/source/driver.cpp b/source/driver.cpp index 6128759c1d9..4a8fe8af3fe 100644 --- a/source/driver.cpp +++ b/source/driver.cpp @@ -85,21 +85,27 @@ void Driver::reading(void) void Driver::atomic_world(void) { ModuleBase::TITLE("Driver","atomic_world"); - //-------------------------------------------------- // choose basis sets: // pw: plane wave basis set // lcao_in_pw: LCAO expaned by plane wave basis set // lcao: linear combination of atomic orbitals //-------------------------------------------------- + string use_ensol; + ModuleEnSover::En_Solver *p_ensolver; if(GlobalV::BASIS_TYPE=="pw" || GlobalV::BASIS_TYPE=="lcao_in_pw") { - Run_pw::plane_wave_line(); + use_ensol = "ksdft_pw"; + //We set it temporarily + //Finally, we have ksdft_pw, ksdft_lcao, sdft_pw, ofdft, lj, eam, etc. + ModuleEnSover::init_esolver(p_ensolver, use_ensol); + Run_pw::plane_wave_line(p_ensolver); + ModuleEnSover::clean_esolver(p_ensolver); } #ifdef __LCAO else if(GlobalV::BASIS_TYPE=="lcao") { - Run_lcao::lcao_line(); + Run_lcao::lcao_line(p_ensolver); } #endif diff --git a/source/module_ensolver/FP/KSDFT/PW/ks_scf_pw.cpp b/source/module_ensolver/FP/KSDFT/PW/ks_scf_pw.cpp new file mode 100644 index 00000000000..11b3743fe95 --- /dev/null +++ b/source/module_ensolver/FP/KSDFT/PW/ks_scf_pw.cpp @@ -0,0 +1,188 @@ +#include "ks_scf_pw.h" + +//--------------temporary---------------------------- +#include "../../../../src_pw/global.h" +#include "../../../../module_base/global_function.h" +#include "../../../../module_symmetry/symmetry.h" +#include "../../../../src_pw/vdwd2.h" +#include "../../../../src_pw/vdwd3.h" +#include "../../../../src_pw/vdwd2_parameters.h" +#include "../../../../src_pw/vdwd3_parameters.h" +#include "../../../../src_pw/pw_complement.h" +#include "../../../../src_pw/pw_basis.h" +#include "../../../../src_io/print_info.h" +//-----force------------------- +#include "../../../../src_pw/forces.h" +//-----stress------------------ +#include "../../../../src_pw/stress_pw.h" +//--------------------------------------------------- + +namespace ModuleEnSover +{ + +void KS_SCF_PW::Init(Input &inp, UnitCell_pseudo &ucell) +{ + // setup GlobalV::NBANDS + // Yu Liu add 2021-07-03 + GlobalC::CHR.cal_nelec(); + + // mohan add 2010-09-06 + // Yu Liu move here 2021-06-27 + // because the number of element type + // will easily be ignored, so here + // I warn the user again for each type. + for(int it=0; itph2e=new H2E_SDFT(); + // } + // else + // { + // this->ph2e=new H2E_PW(); + // } + // this->pes= new Estate_PW(); + // this->phamilt=new Hamilt_PW(); + // } + + // Basis_PW basis_pw; + // Init(Inputs &inp, Cell &cel) + // { + + // basis_pw.init(inp, cel); + + // pes->init(inp, cel, basis_pw); + + // phamilt->init(bas); + // phamilt->initpot(cel, pes); + + // ph2e->init(h, pes); + // } +}; +} +#endif \ No newline at end of file diff --git a/source/module_ensolver/FP/KSDFT/ks_scf.cpp b/source/module_ensolver/FP/KSDFT/ks_scf.cpp new file mode 100644 index 00000000000..e69de29bb2d diff --git a/source/module_ensolver/FP/KSDFT/ks_scf.h b/source/module_ensolver/FP/KSDFT/ks_scf.h new file mode 100644 index 00000000000..fd3acbe20c3 --- /dev/null +++ b/source/module_ensolver/FP/KSDFT/ks_scf.h @@ -0,0 +1,14 @@ +#include "../ab_initio.h" +// #include "estates.h" +// #include "h2e.h" +namespace ModuleEnSover +{ + +class KS_SCF: public ab_initio +{ + public: + // Psi *p_wf; + // Estate *p_es; + // H2E *p_h2e; +}; +} \ No newline at end of file diff --git a/source/module_ensolver/FP/OFDFT/ofdft.h b/source/module_ensolver/FP/OFDFT/ofdft.h new file mode 100644 index 00000000000..39b2fae4fa1 --- /dev/null +++ b/source/module_ensolver/FP/OFDFT/ofdft.h @@ -0,0 +1,8 @@ +#include "../ab_initio.h" +namespace ModuleEnSover +{ +class OFDFT: public ab_initio +{ + +}; +} \ No newline at end of file diff --git a/source/module_ensolver/FP/ab_initio.cpp b/source/module_ensolver/FP/ab_initio.cpp new file mode 100644 index 00000000000..e69de29bb2d diff --git a/source/module_ensolver/FP/ab_initio.h b/source/module_ensolver/FP/ab_initio.h new file mode 100644 index 00000000000..92f371bfc64 --- /dev/null +++ b/source/module_ensolver/FP/ab_initio.h @@ -0,0 +1,14 @@ +#ifndef AB_INITIO_H +#define AB_INITIO_H +#include "../en_solver.h" +// #include "hamilt.h" +namespace ModuleEnSover +{ +class ab_initio: public En_Solver +{ + public: + // Hamilt* phamilt; +}; +} + +#endif \ No newline at end of file diff --git a/source/module_ensolver/Makefile.ensolver b/source/module_ensolver/Makefile.ensolver new file mode 100644 index 00000000000..ccc4e33cdfd --- /dev/null +++ b/source/module_ensolver/Makefile.ensolver @@ -0,0 +1,13 @@ +VPATH:=$(VPATH)\ +:./module_ensolver\ +:./module_ensolver/FP\ +:./module_ensolver/FP/KSDFT\ +:./module_ensolver/FP/KSDFT/PW + +OBJS_ENSOLVER=en_solver.o\ +ab_initio.o\ +ks_scf.o\ +ks_scf_pw.o + +OBJS_FIRST_PRINCIPLES:= ${OBJS_FIRST_PRINCIPLES} \ +${OBJS_ENSOLVER} \ No newline at end of file diff --git a/source/module_ensolver/en_solver.cpp b/source/module_ensolver/en_solver.cpp new file mode 100644 index 00000000000..e8458240687 --- /dev/null +++ b/source/module_ensolver/en_solver.cpp @@ -0,0 +1,38 @@ +#include "en_solver.h" +#include "FP/KSDFT/PW/ks_scf_pw.h" +#include "FP/OFDFT/ofdft.h" +#include "stdio.h" +namespace ModuleEnSover +{ +void En_Solver:: printag() +{ + std::cout<* vel); static void force_virial( + ModuleEnSover::En_Solver *p_ensolver, const int &istep, const MD_parameters &mdp, const UnitCell_pseudo &unit_in, diff --git a/source/module_md/MSST.cpp b/source/module_md/MSST.cpp index cc0f79a782d..b339393950b 100644 --- a/source/module_md/MSST.cpp +++ b/source/module_md/MSST.cpp @@ -4,6 +4,7 @@ #include "mpi.h" #endif #include "../module_base/timer.h" +#include "module_ensolver/en_solver.h" MSST::MSST(MD_parameters& MD_para_in, UnitCell_pseudo &unit_in) : Verlet(MD_para_in, unit_in) { @@ -32,14 +33,14 @@ MSST::~MSST() delete []old_v; } -void MSST::setup() +void MSST::setup(ModuleEnSover::En_Solver *p_ensolver) { ModuleBase::TITLE("MSST", "setup"); ModuleBase::timer::tick("MSST", "setup"); int sd = mdp.direction; - MD_func::force_virial(step_, mdp, ucell, potential, force, virial); + MD_func::force_virial(p_ensolver, step_, mdp, ucell, potential, force, virial); MD_func::kinetic_stress(ucell, vel, allmass, kinetic, stress); stress += virial; diff --git a/source/module_md/MSST.h b/source/module_md/MSST.h index dfee7a58205..36f17c0dad1 100644 --- a/source/module_md/MSST.h +++ b/source/module_md/MSST.h @@ -9,7 +9,7 @@ class MSST : public Verlet MSST(MD_parameters& MD_para_in, UnitCell_pseudo &unit_in); ~MSST(); - void setup(); + void setup(ModuleEnSover::En_Solver *p_ensolve); void first_half(); void second_half(); void outputMD(); diff --git a/source/module_md/NVE.cpp b/source/module_md/NVE.cpp index 6bfb8aab0e3..b353b4ceccd 100644 --- a/source/module_md/NVE.cpp +++ b/source/module_md/NVE.cpp @@ -6,12 +6,12 @@ NVE::NVE(MD_parameters& MD_para_in, UnitCell_pseudo &unit_in) : Verlet(MD_para_i NVE::~NVE(){} -void NVE::setup() +void NVE::setup(ModuleEnSover::En_Solver *p_ensolve) { ModuleBase::TITLE("NVE", "setup"); ModuleBase::timer::tick("NVE", "setup"); - Verlet::setup(); + Verlet::setup(p_ensolve); ModuleBase::timer::tick("NVE", "setup"); } diff --git a/source/module_md/NVE.h b/source/module_md/NVE.h index a71339ffebd..b3ac6bfeb37 100644 --- a/source/module_md/NVE.h +++ b/source/module_md/NVE.h @@ -9,7 +9,7 @@ class NVE : public Verlet NVE(MD_parameters& MD_para_in, UnitCell_pseudo &unit_in); ~NVE(); - void setup(); + void setup(ModuleEnSover::En_Solver *p_ensolve); void first_half(); void second_half(); void outputMD(); diff --git a/source/module_md/NVT_ADS.cpp b/source/module_md/NVT_ADS.cpp index d891eecbcdb..fbe79aadd07 100644 --- a/source/module_md/NVT_ADS.cpp +++ b/source/module_md/NVT_ADS.cpp @@ -15,12 +15,12 @@ NVT_ADS::NVT_ADS(MD_parameters& MD_para_in, UnitCell_pseudo &unit_in) : Verlet(M NVT_ADS::~NVT_ADS(){} -void NVT_ADS::setup() +void NVT_ADS::setup(ModuleEnSover::En_Solver *p_ensolve) { ModuleBase::TITLE("NVT_ADS", "setup"); ModuleBase::timer::tick("NVT_ADS", "setup"); - Verlet::setup(); + Verlet::setup(p_ensolve); ModuleBase::timer::tick("NVT_ADS", "setup"); } diff --git a/source/module_md/NVT_ADS.h b/source/module_md/NVT_ADS.h index 8870065d80f..b5b01bccbe1 100644 --- a/source/module_md/NVT_ADS.h +++ b/source/module_md/NVT_ADS.h @@ -9,7 +9,7 @@ class NVT_ADS : public Verlet NVT_ADS(MD_parameters& MD_para_in, UnitCell_pseudo &unit_in); ~NVT_ADS(); - void setup(); + void setup(ModuleEnSover::En_Solver *p_ensolve); void first_half(); void second_half(); void outputMD(); diff --git a/source/module_md/NVT_NHC.cpp b/source/module_md/NVT_NHC.cpp index 53bd7dc7194..96f02d1f3f1 100644 --- a/source/module_md/NVT_NHC.cpp +++ b/source/module_md/NVT_NHC.cpp @@ -46,12 +46,12 @@ NVT_NHC::~NVT_NHC() delete []veta; } -void NVT_NHC::setup() +void NVT_NHC::setup(ModuleEnSover::En_Solver *p_ensolve) { ModuleBase::TITLE("NVT_NHC", "setup"); ModuleBase::timer::tick("NVT_NHC", "setup"); - Verlet::setup(); + Verlet::setup(p_ensolve); temp_target(); diff --git a/source/module_md/NVT_NHC.h b/source/module_md/NVT_NHC.h index c0bf3d51781..e38b3a3b64f 100644 --- a/source/module_md/NVT_NHC.h +++ b/source/module_md/NVT_NHC.h @@ -9,7 +9,7 @@ class NVT_NHC : public Verlet NVT_NHC(MD_parameters& MD_para_in, UnitCell_pseudo &unit_in); ~NVT_NHC(); - void setup(); + void setup(ModuleEnSover::En_Solver *p_ensolve); void first_half(); void second_half(); void outputMD(); diff --git a/source/module_md/run_md_classic.cpp b/source/module_md/run_md_classic.cpp index 3e693976b9c..0ecc9521a3c 100644 --- a/source/module_md/run_md_classic.cpp +++ b/source/module_md/run_md_classic.cpp @@ -9,6 +9,7 @@ #include "../input.h" #include "../src_io/print_info.h" #include "../module_base/timer.h" +#include "module_ensolver/en_solver.h" Run_MD_CLASSIC::Run_MD_CLASSIC(){} @@ -18,6 +19,7 @@ void Run_MD_CLASSIC::classic_md_line(void) { ModuleBase::TITLE("Run_MD_CLASSIC", "classic_md_line"); ModuleBase::timer::tick("Run_MD_CLASSIC", "classic_md_line"); + ModuleEnSover::En_Solver* p_ensolver; //qianrui add it temporarily // Setup the unitcell. #ifdef __LCAO @@ -59,14 +61,14 @@ void Run_MD_CLASSIC::classic_md_line(void) { if(verlet->step_ == 0) { - verlet->setup(); + verlet->setup(p_ensolver); } else { verlet->first_half(); // update force and virial due to the update of atom positions - MD_func::force_virial(verlet->step_, verlet->mdp, verlet->ucell, verlet->potential, verlet->force, verlet->virial); + MD_func::force_virial(p_ensolver, verlet->step_, verlet->mdp, verlet->ucell, verlet->potential, verlet->force, verlet->virial); verlet->second_half(); diff --git a/source/module_md/verlet.cpp b/source/module_md/verlet.cpp index 239a6a13c34..5d8a2c3be09 100644 --- a/source/module_md/verlet.cpp +++ b/source/module_md/verlet.cpp @@ -4,6 +4,7 @@ #include "mpi.h" #endif #include "../module_base/timer.h" +#include "module_ensolver/en_solver.h" Verlet::Verlet(MD_parameters& MD_para_in, UnitCell_pseudo &unit_in): mdp(MD_para_in), @@ -46,14 +47,14 @@ Verlet::~Verlet() delete []force; } -void Verlet::setup() +void Verlet::setup(ModuleEnSover::En_Solver *p_ensolver) { if(mdp.rstMD) { restart(); } - MD_func::force_virial(step_, mdp, ucell, potential, force, virial); + MD_func::force_virial(p_ensolver, step_, mdp, ucell, potential, force, virial); MD_func::kinetic_stress(ucell, vel, allmass, kinetic, stress); stress += virial; diff --git a/source/module_md/verlet.h b/source/module_md/verlet.h index bdca2da825f..b96e1d20565 100644 --- a/source/module_md/verlet.h +++ b/source/module_md/verlet.h @@ -4,6 +4,7 @@ #include "MD_parameters.h" #include "../module_cell/unitcell_pseudo.h" #include "../module_base/matrix.h" +#include "module_ensolver/en_solver.h" class Verlet { @@ -11,7 +12,7 @@ class Verlet Verlet(MD_parameters& MD_para_in, UnitCell_pseudo &unit_in); virtual ~Verlet(); - virtual void setup(); + virtual void setup(ModuleEnSover::En_Solver *p_ensolver); virtual void first_half(); virtual void second_half(); virtual void outputMD(); diff --git a/source/run_lcao.cpp b/source/run_lcao.cpp index dad53cf0cde..d6cfe1927a3 100644 --- a/source/run_lcao.cpp +++ b/source/run_lcao.cpp @@ -17,7 +17,7 @@ Run_lcao::Run_lcao(){} Run_lcao::~Run_lcao(){} -void Run_lcao::lcao_line(void) +void Run_lcao::lcao_line(ModuleEnSover::En_Solver *p_ensolver) { ModuleBase::TITLE("Run_lcao","lcao_line"); ModuleBase::timer::tick("Run_lcao","lcao_line"); @@ -181,7 +181,7 @@ void Run_lcao::lcao_line(void) if(GlobalV::CALCULATION=="md") { Run_MD_LCAO run_md_lcao; - run_md_lcao.opt_cell(); + run_md_lcao.opt_cell(p_ensolver); } else // cell relaxations { diff --git a/source/run_lcao.h b/source/run_lcao.h index b3c21ca176f..9b1e2adeed4 100644 --- a/source/run_lcao.h +++ b/source/run_lcao.h @@ -8,6 +8,7 @@ #include "module_base/global_function.h" #include "module_base/global_variable.h" #include "input.h" +#include "module_ensolver/en_solver.h" class Run_lcao { @@ -18,7 +19,7 @@ class Run_lcao ~Run_lcao(); // perform Linear Combination of Atomic Orbitals (LCAO) calculations - static void lcao_line(void); + static void lcao_line(ModuleEnSover::En_Solver *p_ensolver); }; diff --git a/source/run_pw.cpp b/source/run_pw.cpp index 6b88884f384..cd526afb24b 100644 --- a/source/run_pw.cpp +++ b/source/run_pw.cpp @@ -1,21 +1,18 @@ #include "run_pw.h" #include "src_pw/global.h" -#include "src_pw/energy.h" -#include "input.h" -#include "src_io/optical.h" #include "src_io/cal_test.h" #include "src_io/winput.h" +#include "src_io/optical.h" #include "src_io/numerical_basis.h" #include "src_io/numerical_descriptor.h" #include "src_io/print_info.h" -#include "module_symmetry/symmetry.h" #include "src_ions/Cell_PW.h" #include "src_pw/run_md_pw.h" Run_pw::Run_pw(){} Run_pw::~Run_pw(){} -void Run_pw::plane_wave_line(void) +void Run_pw::plane_wave_line(ModuleEnSover::En_Solver *p_ensolver) { ModuleBase::TITLE("Run_pw","plane_wave_line"); ModuleBase::timer::tick("Run_pw","plane_wave_line"); @@ -28,44 +25,6 @@ void Run_pw::plane_wave_line(void) #else GlobalC::ucell.setup_cell( GlobalV::global_pseudo_dir, GlobalV::global_atom_card, GlobalV::ofs_running); #endif - //GlobalC::ucell.setup_cell( GlobalV::global_pseudo_dir , GlobalV::global_atom_card , GlobalV::ofs_running, GlobalV::NLOCAL, GlobalV::NBANDS); - - // setup GlobalV::NBANDS - // Yu Liu add 2021-07-03 - GlobalC::CHR.cal_nelec(); - - // mohan add 2010-09-06 - // Yu Liu move here 2021-06-27 - // because the number of element type - // will easily be ignored, so here - // I warn the user again for each type. - for(int it=0; itInit(INPUT, GlobalC::ucell); - // Calculate Structure factor - GlobalC::pw.setup_structure_factor(); - // cout<<"after pgrid init nrxx = "<Run(istep,GlobalC::ucell); + p_ensolver->cal_Energy(GlobalC::en); eiter = elec.iter; #ifdef __LCAO } @@ -178,12 +171,6 @@ void Ions::opt_ions_pw(void) elec.non_self_consistent(istep-1); eiter = elec.iter; } - // mohan added 2021-01-28, perform stochastic calculations - else if(GlobalV::CALCULATION=="scf-sto" || GlobalV::CALCULATION=="relax-sto" || GlobalV::CALCULATION=="md-sto") - { - elec_sto.scf_stochastic(istep-1); - eiter = elec_sto.iter; - } if(GlobalC::pot.out_potential == 2) { @@ -200,7 +187,7 @@ void Ions::opt_ions_pw(void) if (GlobalV::CALCULATION=="scf" || GlobalV::CALCULATION=="relax" || GlobalV::CALCULATION=="cell-relax") { - stop = this->after_scf(istep, force_step, stress_step); // pengfei Li 2018-05-14 + stop = this->after_scf(p_ensolver, istep, force_step, stress_step); // pengfei Li 2018-05-14 } time_t fend = time(NULL); @@ -251,20 +238,20 @@ void Ions::opt_ions_pw(void) return; } -bool Ions::after_scf(const int &istep, int &force_step, int &stress_step) +bool Ions::after_scf(ModuleEnSover::En_Solver *p_ensolver, const int &istep, int &force_step, int &stress_step) { ModuleBase::TITLE("Ions","after_scf"); //calculate and gather all parts of total ionic forces ModuleBase::matrix force; if(GlobalV::FORCE) { - this->gather_force_pw(force); + this->gather_force_pw(p_ensolver, force); } //calculate and gather all parts of stress ModuleBase::matrix stress; if(GlobalV::STRESS) { - this->gather_stress_pw(stress); + this->gather_stress_pw(p_ensolver, stress); } //stop in last step if(istep==GlobalV::NSTEP) @@ -299,18 +286,21 @@ bool Ions::after_scf(const int &istep, int &force_step, int &stress_step) return 1; } -void Ions::gather_force_pw(ModuleBase::matrix &force) +void Ions::gather_force_pw(ModuleEnSover::En_Solver *p_ensolver, ModuleBase::matrix &force) { ModuleBase::TITLE("Ions","gather_force_pw"); - Forces fcs; - fcs.init(force); + // Forces fcs; + // fcs.init(force); + p_ensolver->cal_Force(force); } -void Ions::gather_stress_pw(ModuleBase::matrix& stress) + +void Ions::gather_stress_pw(ModuleEnSover::En_Solver *p_ensolver, ModuleBase::matrix& stress) { ModuleBase::TITLE("Ions","gather_stress_pw"); //basic stress - Stress_PW ss; - ss.cal_stress(stress); + // Stress_PW ss; + // ss.cal_stress(stress); + p_ensolver->cal_Stress(stress); //external stress double unit_transform = 0.0; unit_transform = ModuleBase::RYDBERG_SI / pow(ModuleBase::BOHR_RADIUS_SI,3) * 1.0e-8; diff --git a/source/src_ions/ions.h b/source/src_ions/ions.h index f2b0aac6483..3baba13652c 100644 --- a/source/src_ions/ions.h +++ b/source/src_ions/ions.h @@ -9,6 +9,7 @@ #include "../src_pw/sto_elec.h" //mohan added 2021-01-28 #include "ions_move_methods.h" #include "lattice_change_methods.h" +#include "module_ensolver/en_solver.h" class Ions { @@ -18,7 +19,7 @@ class Ions Ions(){}; ~Ions(){}; - void opt_ions_pw(void); + void opt_ions_pw(ModuleEnSover::En_Solver *p_ensolver); private: @@ -41,9 +42,9 @@ class Ions Lattice_Change_Methods LCM; //seperate force_stress function first - bool after_scf(const int &istep, int &force_step, int &stress_step); - void gather_force_pw(ModuleBase::matrix &force); - void gather_stress_pw(ModuleBase::matrix& stress); + bool after_scf(ModuleEnSover::En_Solver *p_ensolver,const int &istep, int &force_step, int &stress_step); + void gather_force_pw(ModuleEnSover::En_Solver *p_ensolver, ModuleBase::matrix &force); + void gather_stress_pw(ModuleEnSover::En_Solver *p_ensolver, ModuleBase::matrix& stress); bool if_do_relax(); bool if_do_cellrelax(); bool do_relax(const int& istep, int& jstep, const ModuleBase::matrix& ionic_force, const double& total_energy); diff --git a/source/src_lcao/run_md_lcao.cpp b/source/src_lcao/run_md_lcao.cpp index 2df66eb3ec3..6feef8ea74a 100644 --- a/source/src_lcao/run_md_lcao.cpp +++ b/source/src_lcao/run_md_lcao.cpp @@ -31,7 +31,7 @@ Run_MD_LCAO::Run_MD_LCAO() Run_MD_LCAO::~Run_MD_LCAO(){} -void Run_MD_LCAO::opt_cell(void) +void Run_MD_LCAO::opt_cell(ModuleEnSover::En_Solver *p_ensolver) { ModuleBase::TITLE("Run_MD_LCAO","opt_cell"); @@ -56,12 +56,12 @@ void Run_MD_LCAO::opt_cell(void) GlobalC::pot.init_pot(ion_step, GlobalC::pw.strucFac); - opt_ions(); + opt_ions(p_ensolver); return; } -void Run_MD_LCAO::opt_ions(void) +void Run_MD_LCAO::opt_ions(ModuleEnSover::En_Solver *p_ensolver) { ModuleBase::TITLE("Run_MD_LCAO","opt_ions"); ModuleBase::timer::tick("Run_MD_LCAO","opt_ions"); @@ -114,7 +114,7 @@ void Run_MD_LCAO::opt_ions(void) { if(verlet->step_ == 0) { - verlet->setup(); + verlet->setup(p_ensolver); } else { @@ -143,7 +143,7 @@ void Run_MD_LCAO::opt_ions(void) GlobalC::pot.init_pot(verlet->step_, GlobalC::pw.strucFac); // update force and virial due to the update of atom positions - MD_func::force_virial(verlet->step_, verlet->mdp, verlet->ucell, verlet->potential, verlet->force, verlet->virial); + MD_func::force_virial(p_ensolver, verlet->step_, verlet->mdp, verlet->ucell, verlet->potential, verlet->force, verlet->virial); verlet->second_half(); @@ -198,6 +198,7 @@ void Run_MD_LCAO::opt_ions(void) } void Run_MD_LCAO::md_force_virial( + ModuleEnSover::En_Solver *p_ensolver, const int &istep, const int& numIon, double &potential, diff --git a/source/src_lcao/run_md_lcao.h b/source/src_lcao/run_md_lcao.h index 8d8694e1e2e..00f3d293913 100644 --- a/source/src_lcao/run_md_lcao.h +++ b/source/src_lcao/run_md_lcao.h @@ -2,6 +2,7 @@ #define RUN_MD_LCAO_H #include "../src_pw/charge_extra.h" +#include "module_ensolver/en_solver.h" class Run_MD_LCAO { @@ -11,9 +12,10 @@ class Run_MD_LCAO Run_MD_LCAO(); ~Run_MD_LCAO(); - void opt_cell(void); - void opt_ions(void); - void md_force_virial(const int &istep, + void opt_cell(ModuleEnSover::En_Solver *p_ensolver); + void opt_ions(ModuleEnSover::En_Solver *p_ensolver); + void md_force_virial(ModuleEnSover::En_Solver *p_ensolver, + const int &istep, const int& numIon, double &potential, ModuleBase::Vector3* force, diff --git a/source/src_pw/run_md_pw.cpp b/source/src_pw/run_md_pw.cpp index 753bfd20850..7f84765c272 100644 --- a/source/src_pw/run_md_pw.cpp +++ b/source/src_pw/run_md_pw.cpp @@ -1,13 +1,5 @@ #include "run_md_pw.h" -#include "forces.h" -#include "stress_pw.h" #include "global.h" // use chr. -#include "vdwd2.h" -#include "vdwd3.h" -#include "vdwd2_parameters.h" -#include "vdwd3_parameters.h" -#include "pw_complement.h" -#include "pw_basis.h" #include "../src_ions/variable_cell.h" // mohan add 2021-02-01 #include "../module_md/MD_func.h" #include "../module_md/FIRE.h" @@ -25,7 +17,7 @@ Run_MD_PW::Run_MD_PW() Run_MD_PW::~Run_MD_PW(){} -void Run_MD_PW::md_ions_pw(void) +void Run_MD_PW::md_ions_pw(ModuleEnSover::En_Solver *p_ensolver) { ModuleBase::TITLE("Run_MD_PW", "md_ions_pw"); ModuleBase::timer::tick("Run_MD_PW", "md_ions_pw"); @@ -81,7 +73,7 @@ void Run_MD_PW::md_ions_pw(void) { if(verlet->step_ == 0) { - verlet->setup(); + verlet->setup(p_ensolver); } else { @@ -113,7 +105,7 @@ void Run_MD_PW::md_ions_pw(void) GlobalC::wf.wfcinit(); // update force and virial due to the update of atom positions - MD_func::force_virial(verlet->step_, verlet->mdp, verlet->ucell, verlet->potential, verlet->force, verlet->virial); + MD_func::force_virial(p_ensolver, verlet->step_, verlet->mdp, verlet->ucell, verlet->potential, verlet->force, verlet->virial); verlet->second_half(); @@ -159,81 +151,14 @@ void Run_MD_PW::md_ions_pw(void) return; } -void Run_MD_PW::md_cells_pw() +void Run_MD_PW::md_cells_pw(ModuleEnSover::En_Solver *p_ensolver) { ModuleBase::TITLE("Run_MD_PW", "md_cells_pw"); ModuleBase::timer::tick("Run_MD_PW", "md_cells_pw"); - GlobalC::wf.allocate(GlobalC::kv.nks); - - GlobalC::UFFT.allocate(); - - //======================= - // init pseudopotential - //======================= - GlobalC::ppcell.init(GlobalC::ucell.ntype); - - //===================== - // init hamiltonian - // only allocate in the beginning of ELEC LOOP! - //===================== - GlobalC::hm.hpw.allocate(GlobalC::wf.npwx, GlobalV::NPOL, GlobalC::ppcell.nkb, GlobalC::pw.nrxx); - - //================================= - // initalize local pseudopotential - //================================= - GlobalC::ppcell.init_vloc(GlobalC::pw.nggm, GlobalC::ppcell.vloc); - ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "LOCAL POTENTIAL"); - - //====================================== - // Initalize non local pseudopotential - //====================================== - GlobalC::ppcell.init_vnl(GlobalC::ucell); - ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "NON-LOCAL POTENTIAL"); - - //========================================================= - // calculate the total local pseudopotential in real space - //========================================================= - GlobalC::pot.init_pot(0, GlobalC::pw.strucFac); //atomic_rho, v_of_rho, set_vrs - - GlobalC::pot.newd(); - - ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "INIT POTENTIAL"); - - //================================================== - // create GlobalC::ppcell.tab_at , for trial wave functions. - //================================================== - GlobalC::wf.init_at_1(); - - //================================ - // Initial start wave functions - //================================ - if (GlobalV::NBANDS != 0 ) // liuyu update 2021-12-10 - { - GlobalC::wf.wfcinit(); - } -#ifdef __LCAO - switch (GlobalC::exx_global.info.hybrid_type) // Peize Lin add 2019-03-09 - { - case Exx_Global::Hybrid_Type::HF: - case Exx_Global::Hybrid_Type::PBE0: - case Exx_Global::Hybrid_Type::HSE: - GlobalC::exx_lip.init(&GlobalC::kv, &GlobalC::wf, &GlobalC::pw, &GlobalC::UFFT, &GlobalC::ucell); - break; - case Exx_Global::Hybrid_Type::No: - break; - case Exx_Global::Hybrid_Type::Generate_Matrix: - default: - throw std::invalid_argument(ModuleBase::GlobalFunc::TO_STRING(__FILE__) + ModuleBase::GlobalFunc::TO_STRING(__LINE__)); - } -#endif - - ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "INIT BASIS"); - // ion optimization begins // electron density optimization is included in ion optimization - - this->md_ions_pw(); + this->md_ions_pw(p_ensolver); GlobalV::ofs_running << "\n\n --------------------------------------------" << std::endl; GlobalV::ofs_running << std::setprecision(16); @@ -244,6 +169,7 @@ void Run_MD_PW::md_cells_pw() } void Run_MD_PW::md_force_virial( + ModuleEnSover::En_Solver *p_ensolver, const int &istep, const int& numIon, double &potential, @@ -327,8 +253,9 @@ void Run_MD_PW::md_force_virial( } ModuleBase::matrix fcs; - Forces ff; - ff.init(fcs); + // Forces ff; + // ff.init(fcs); + p_ensolver->cal_Force(fcs); for(int ion=0;ioncal_Stress(virial); virial = 0.5 * virial; } diff --git a/source/src_pw/run_md_pw.h b/source/src_pw/run_md_pw.h index 80f3d5ccbab..693f25e217e 100644 --- a/source/src_pw/run_md_pw.h +++ b/source/src_pw/run_md_pw.h @@ -3,6 +3,7 @@ #define RUN_MD_PW_H #include "../src_pw/charge_extra.h" +#include "module_ensolver/en_solver.h" class Run_MD_PW { @@ -10,9 +11,11 @@ class Run_MD_PW Run_MD_PW(); ~Run_MD_PW(); - void md_ions_pw(); - void md_cells_pw(); - void md_force_virial(const int &istep, + void md_ions_pw(ModuleEnSover::En_Solver *p_ensolver); + void md_cells_pw(ModuleEnSover::En_Solver *p_ensolver); + void md_force_virial( + ModuleEnSover::En_Solver *p_ensolver, + const int &istep, const int& numIon, double &potential, ModuleBase::Vector3* force, From ad3d3eff2d450c1862ffa27fdd1b41d053691af0 Mon Sep 17 00:00:00 2001 From: qianrui Date: Thu, 24 Feb 2022 16:05:04 +0800 Subject: [PATCH 38/52] add CMakeLists.txt --- CMakeLists.txt | 1 + source/CMakeLists.txt | 1 + source/module_ensolver/CMakeLists.txt | 10 ++++++++++ 3 files changed, 12 insertions(+) create mode 100644 source/module_ensolver/CMakeLists.txt diff --git a/CMakeLists.txt b/CMakeLists.txt index 1d4415a82d4..8716afbc8c7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -272,6 +272,7 @@ target_link_libraries(${ABACUS_BIN_NAME} pw ri driver + en_solver -lm ) diff --git a/source/CMakeLists.txt b/source/CMakeLists.txt index e65955d5989..fb72efffb62 100644 --- a/source/CMakeLists.txt +++ b/source/CMakeLists.txt @@ -5,6 +5,7 @@ add_subdirectory(module_neighbor) add_subdirectory(module_orbital) add_subdirectory(module_md) add_subdirectory(module_deepks) +add_subdirectory(module_ensolver) add_subdirectory(src_io) add_subdirectory(src_ions) add_subdirectory(src_lcao) diff --git a/source/module_ensolver/CMakeLists.txt b/source/module_ensolver/CMakeLists.txt new file mode 100644 index 00000000000..8f9d75dfcb0 --- /dev/null +++ b/source/module_ensolver/CMakeLists.txt @@ -0,0 +1,10 @@ +add_library( + en_solver + OBJECT + en_solver.cpp + FP/ab_initio.cpp + FP/KSDFT/ks_scf.cpp + FP/KSDFT/PW/ks_scf_pw.cpp +) + + From 01941204ee552f6d15e42a589c725bb1a50865f9 Mon Sep 17 00:00:00 2001 From: qianrui Date: Thu, 24 Feb 2022 16:44:36 +0800 Subject: [PATCH 39/52] add autotest for pw_refactor when PR --- .github/workflows/test.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c9731df9b92..e7cdd42dc31 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -8,6 +8,7 @@ on: - ABACUS_2.2.0_beta - deepks - planewave + - pw_refactor jobs: test: From 0a975a944fd825f200e49e1487b9b207c717f2f1 Mon Sep 17 00:00:00 2001 From: qianrui Date: Thu, 24 Feb 2022 16:53:11 +0800 Subject: [PATCH 40/52] add integration test for pw_refactor when PR --- .github/workflows/test.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c9731df9b92..e7cdd42dc31 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -8,6 +8,7 @@ on: - ABACUS_2.2.0_beta - deepks - planewave + - pw_refactor jobs: test: From cca9d48cc261f140a49e79c78a0d2babd5ea24ae Mon Sep 17 00:00:00 2001 From: qianrui Date: Thu, 24 Feb 2022 17:45:49 +0800 Subject: [PATCH 41/52] modify tests in module_md --- source/module_md/test/FIRE_test.cpp | 3 ++- source/module_md/test/Langevin_test.cpp | 3 ++- source/module_md/test/MSST_test.cpp | 3 ++- source/module_md/test/NVE_test.cpp | 4 +++- source/module_md/test/NVT_ADS_test.cpp | 3 ++- source/module_md/test/NVT_NHC_test.cpp | 3 ++- 6 files changed, 13 insertions(+), 6 deletions(-) diff --git a/source/module_md/test/FIRE_test.cpp b/source/module_md/test/FIRE_test.cpp index 07e875a0bf9..5f07e3e74a4 100644 --- a/source/module_md/test/FIRE_test.cpp +++ b/source/module_md/test/FIRE_test.cpp @@ -14,7 +14,8 @@ class FIRE_test : public testing::Test Setcell::parameters(); verlet = new FIRE(INPUT.mdp, ucell); - verlet->setup(); + ModuleEnSover::En_Solver *p_ensolver; + verlet->setup(p_ensolver); } void TearDown() diff --git a/source/module_md/test/Langevin_test.cpp b/source/module_md/test/Langevin_test.cpp index 0c24f0ad5ac..c124f7ced25 100644 --- a/source/module_md/test/Langevin_test.cpp +++ b/source/module_md/test/Langevin_test.cpp @@ -14,7 +14,8 @@ class Langevin_test : public testing::Test Setcell::parameters(); verlet = new Langevin(INPUT.mdp, ucell); - verlet->setup(); + ModuleEnSover::En_Solver *p_ensolver; + verlet->setup(p_ensolver); } void TearDown() diff --git a/source/module_md/test/MSST_test.cpp b/source/module_md/test/MSST_test.cpp index 158492b5b6e..50f9bf5c430 100644 --- a/source/module_md/test/MSST_test.cpp +++ b/source/module_md/test/MSST_test.cpp @@ -14,7 +14,8 @@ class MSST_test : public testing::Test Setcell::parameters(); verlet = new MSST(INPUT.mdp, ucell); - verlet->setup(); + ModuleEnSover::En_Solver *p_ensolver; + verlet->setup(p_ensolver); } void TearDown() diff --git a/source/module_md/test/NVE_test.cpp b/source/module_md/test/NVE_test.cpp index f485a98f78f..4276be64e46 100644 --- a/source/module_md/test/NVE_test.cpp +++ b/source/module_md/test/NVE_test.cpp @@ -1,6 +1,7 @@ #include "gtest/gtest.h" #include "setcell.h" #include "module_md/NVE.h" +#include "../../module_ensolver/en_solver.h" class NVE_test : public testing::Test { @@ -14,8 +15,9 @@ class NVE_test : public testing::Test Setcell::setupcell(ucell); Setcell::parameters(); verlet = new NVE(INPUT.mdp, ucell); + ModuleEnSover::En_Solver *p_ensolver; - verlet->setup(); + verlet->setup(p_ensolver); } void TearDown() diff --git a/source/module_md/test/NVT_ADS_test.cpp b/source/module_md/test/NVT_ADS_test.cpp index 461bdc6ca49..c65b675152a 100644 --- a/source/module_md/test/NVT_ADS_test.cpp +++ b/source/module_md/test/NVT_ADS_test.cpp @@ -14,7 +14,8 @@ class NVT_ADS_test : public testing::Test Setcell::parameters(); verlet = new NVT_ADS(INPUT.mdp, ucell); - verlet->setup(); + ModuleEnSover::En_Solver *p_ensolver; + verlet->setup(p_ensolver); } void TearDown() diff --git a/source/module_md/test/NVT_NHC_test.cpp b/source/module_md/test/NVT_NHC_test.cpp index cdaa760480e..ea502208a81 100644 --- a/source/module_md/test/NVT_NHC_test.cpp +++ b/source/module_md/test/NVT_NHC_test.cpp @@ -14,7 +14,8 @@ class NVT_NHC_test : public testing::Test Setcell::parameters(); verlet = new NVT_NHC(INPUT.mdp, ucell); - verlet->setup(); + ModuleEnSover::En_Solver *p_ensolver; + verlet->setup(p_ensolver); } void TearDown() From fbf0d2d0928c214a6ca1c7ad9ab32dc4955fb666 Mon Sep 17 00:00:00 2001 From: maki49 <1579492865@qq.com> Date: Sat, 26 Feb 2022 21:31:04 +0800 Subject: [PATCH 42/52] fix compile error with test by cmake --- source/module_orbital/test/CMakeLists.txt | 2 ++ source/src_pdiag/pdiag_basic.cpp | 1 - 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/source/module_orbital/test/CMakeLists.txt b/source/module_orbital/test/CMakeLists.txt index 9c780a75266..ea1423de477 100644 --- a/source/module_orbital/test/CMakeLists.txt +++ b/source/module_orbital/test/CMakeLists.txt @@ -35,6 +35,8 @@ list(APPEND depend_files ../ORB_table_alpha.cpp ../ORB_gen_tables.cpp ../../src_lcao/center2_orb-orb11.cpp + ../../src_parallel/parallel_orbitals.cpp + ../../src_pdiag/pdiag_basic.cpp ) AddTest( TARGET orbital_equal_test diff --git a/source/src_pdiag/pdiag_basic.cpp b/source/src_pdiag/pdiag_basic.cpp index dc5ec2007b2..4184a9f748b 100644 --- a/source/src_pdiag/pdiag_basic.cpp +++ b/source/src_pdiag/pdiag_basic.cpp @@ -1,7 +1,6 @@ #include "../src_parallel/parallel_common.h" #include "src_parallel/parallel_orbitals.h" #include "src_pdiag/pdiag_double.h" -#include "../src_pw/global.h" #include "../src_io/wf_local.h" #include "../module_base/lapack_connector.h" #include "../module_base/memory.h" From aab7dca22773709795f59e1839ae6a0716280c13 Mon Sep 17 00:00:00 2001 From: maki49 <1579492865@qq.com> Date: Sat, 26 Feb 2022 22:11:07 +0800 Subject: [PATCH 43/52] remove pdiag_basic.cpp --- source/Makefile.Objects | 1 - source/module_orbital/ORB_control.cpp | 528 +++++++++- source/module_orbital/ORB_control.h | 6 +- source/module_orbital/test/CMakeLists.txt | 1 - source/module_orbital/test/orb_obj/README | 1 + source/src_pdiag/CMakeLists.txt | 1 - source/src_pdiag/pdiag_basic.cpp | 1129 --------------------- source/src_pdiag/pdiag_double.cpp | 605 +++++++++++ 8 files changed, 1135 insertions(+), 1137 deletions(-) create mode 100644 source/module_orbital/test/orb_obj/README delete mode 100644 source/src_pdiag/pdiag_basic.cpp diff --git a/source/Makefile.Objects b/source/Makefile.Objects index 44a72c0d0e2..f486266f201 100644 --- a/source/Makefile.Objects +++ b/source/Makefile.Objects @@ -176,7 +176,6 @@ FORCE_k.o\ parallel_orbitals.o \ global_fp.o \ pdiag_double.o \ -pdiag_basic.o \ pdiag_common.o \ diag_scalapack_gvx.o \ subgrid_oper.o \ diff --git a/source/module_orbital/ORB_control.cpp b/source/module_orbital/ORB_control.cpp index 3ea5a898977..fb2524ed2ce 100644 --- a/source/module_orbital/ORB_control.cpp +++ b/source/module_orbital/ORB_control.cpp @@ -1,6 +1,11 @@ #include "ORB_control.h" #include "ORB_gen_tables.h" #include "../module_base/timer.h" +#include "../src_parallel/parallel_common.h" +#include "../src_io/wf_local.h" +#include "../module_base/lapack_connector.h" +#include "../module_base/memory.h" + //#include "build_st_pw.h" ORB_control::ORB_control() @@ -157,6 +162,524 @@ void ORB_control::setup_2d_division(void) } +void ORB_control::set_parameters(void) +{ + ModuleBase::TITLE("ORB_control","set_parameters"); + + Parallel_Orbitals* pv = &this->ParaV; + // set loc_size + if(GlobalV::GAMMA_ONLY_LOCAL)//xiaohui add 2014-12-21 + { + pv->loc_size=GlobalV::NBANDS/GlobalV::DSIZE; + + // mohan add 2012-03-29 + if(pv->loc_size==0) + { + GlobalV::ofs_warning << " loc_size=0" << " in proc " << GlobalV::MY_RANK+1 << std::endl; + ModuleBase::WARNING_QUIT("ORB_control::set_parameters","NLOCAL < GlobalV::DSIZE"); + } + + if (GlobalV::DRANKloc_size+=1; + if(pv->testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"local size",pv->loc_size); + + // set loc_sizes + delete[] pv->loc_sizes; + pv->loc_sizes = new int[GlobalV::DSIZE]; + ModuleBase::GlobalFunc::ZEROS(pv->loc_sizes, GlobalV::DSIZE); + + pv->lastband_in_proc = 0; + pv->lastband_number = 0; + int count_bands = 0; + for (int i=0; iloc_sizes[i]=GlobalV::NBANDS/GlobalV::DSIZE+1; + } + else + { + pv->loc_sizes[i]=GlobalV::NBANDS/GlobalV::DSIZE; + } + count_bands += pv->loc_sizes[i]; + if (count_bands >= GlobalV::NBANDS) + { + pv->lastband_in_proc = i; + pv->lastband_number = GlobalV::NBANDS - (count_bands - pv->loc_sizes[i]); + break; + } + } + } + else + { + pv->loc_size=GlobalV::NLOCAL/GlobalV::DSIZE; + + // mohan add 2012-03-29 + if(pv->loc_size==0) + { + GlobalV::ofs_warning << " loc_size=0" << " in proc " << GlobalV::MY_RANK+1 << std::endl; + ModuleBase::WARNING_QUIT("ORB_control::set_parameters","NLOCAL < GlobalV::DSIZE"); + } + + if (GlobalV::DRANKloc_size += 1; + } + if(pv->testpb) ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"local size",pv->loc_size); + + // set loc_sizes + delete[] pv->loc_sizes; + pv->loc_sizes = new int[GlobalV::DSIZE]; + ModuleBase::GlobalFunc::ZEROS(pv->loc_sizes, GlobalV::DSIZE); + + pv->lastband_in_proc = 0; + pv->lastband_number = 0; + int count_bands = 0; + for (int i=0; iloc_sizes[i]=GlobalV::NLOCAL/GlobalV::DSIZE+1; + } + else + { + pv->loc_sizes[i]=GlobalV::NLOCAL/GlobalV::DSIZE; + } + count_bands += pv->loc_sizes[i]; + if (count_bands >= GlobalV::NBANDS) + { + pv->lastband_in_proc = i; + pv->lastband_number = GlobalV::NBANDS - (count_bands - pv->loc_sizes[i]); + break; + } + } + }//xiaohui add 2014-12-21 + + if (GlobalV::KS_SOLVER=="hpseps") //LiuXh add 2021-09-06, clear memory, Z_LOC only used in hpseps solver + { + pv->Z_LOC = new double*[GlobalV::NSPIN]; + for(int is=0; isZ_LOC[is] = new double[pv->loc_size * GlobalV::NLOCAL]; + ModuleBase::GlobalFunc::ZEROS(pv->Z_LOC[is], pv->loc_size * GlobalV::NLOCAL); + } + pv->alloc_Z_LOC = true;//xiaohui add 2014-12-22 + } + + if(pv->testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"lastband_in_proc", pv->lastband_in_proc); + if(pv->testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"lastband_number", pv->lastband_number); + + return; +} + + +#ifdef __MPI +// creat the 'comm_2D' stratege. +void ORB_control::mpi_creat_cart(MPI_Comm *comm_2D, int prow, int pcol) +{ + ModuleBase::TITLE("ORB_control","mpi_creat_cart"); + // the matrix is divided as ( dim[0] * dim[1] ) + int dim[2]; + int period[2]={1,1}; + int reorder=0; + dim[0]=prow; + dim[1]=pcol; + + if(this->ParaV.testpb)GlobalV::ofs_running << " dim = " << dim[0] << " * " << dim[1] << std::endl; + + MPI_Cart_create(DIAG_WORLD,2,dim,period,reorder,comm_2D); + return; +} +#endif + +#ifdef __MPI +void ORB_control::mat_2d(MPI_Comm vu, + const int &M_A, + const int &N_A, + const int &nb, + LocalMatrix &LM) +{ + ModuleBase::TITLE("ORB_control", "mat_2d"); + + Parallel_Orbitals* pv = &this->ParaV; + + int dim[2]; + int period[2]; + int coord[2]; + int i,j,k,end_id; + int block; + + // (0) every processor get it's id on the 2D comm + // : ( coord[0], coord[1] ) + MPI_Cart_get(vu,2,dim,period,coord); + + // (1.1) how many blocks at least + // eg. M_A = 6400, nb = 64; + // so block = 10; + block=M_A/nb; + + // (1.2) If data remain, add 1. + if (block*nbtestpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"Total Row Blocks Number",block); + + // mohan add 2010-09-12 + if(dim[0]>block) + { + GlobalV::ofs_warning << " cpu 2D distribution : " << dim[0] << "*" << dim[1] << std::endl; + GlobalV::ofs_warning << " but, the number of row blocks is " << block << std::endl; + ModuleBase::WARNING_QUIT("ORB_control::mat_2d","some processor has no row blocks, try a smaller 'nb2d' parameter."); + } + + // (2.1) row_b : how many blocks for this processor. (at least) + LM.row_b=block/dim[0]; + + // (2.2) row_b : how many blocks in this processor. + // if there are blocks remain, some processors add 1. + if (coord[0]testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"Local Row Block Number",LM.row_b); + + // (3) end_id indicates the last block belong to + // which processor. + if (block%dim[0]==0) + { + end_id=dim[0]-1; + } + else + { + end_id=block%dim[0]-1; + } + + if(pv->testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"Ending Row Block in processor",end_id); + + // (4) row_num : how many rows in this processors : + // the one owns the last block is different. + if (coord[0]==end_id) + { + LM.row_num=(LM.row_b-1)*nb+(M_A-(block-1)*nb); + } + else + { + LM.row_num=LM.row_b*nb; + } + + if(pv->testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"Local rows (including nb)",LM.row_num); + + // (5) row_set, it's a global index : + // save explicitly : every row in this processor + // belongs to which row in the global matrix. + delete[] LM.row_set; + LM.row_set= new int[LM.row_num]; + j=0; + for (i=0; iblock) + { + GlobalV::ofs_warning << " cpu 2D distribution : " << dim[0] << "*" << dim[1] << std::endl; + GlobalV::ofs_warning << " but, the number of column blocks is " << block << std::endl; + ModuleBase::WARNING_QUIT("ORB_control::mat_2d","some processor has no column blocks."); + } + + LM.col_b=block/dim[1]; + if (coord[1]testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"Local Row Block Number",LM.col_b); + + if (block%dim[1]==0) + { + end_id=dim[1]-1; + } + else + { + end_id=block%dim[1]-1; + } + + if(pv->testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"Ending Row Block in processor",end_id); + + if (coord[1]==end_id) + { + LM.col_num=(LM.col_b-1)*nb+(M_A-(block-1)*nb); + } + else + { + LM.col_num=LM.col_b*nb; + } + + if(pv->testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"Local columns (including nb)",LM.row_num); + + delete[] LM.col_set; + LM.col_set = new int[LM.col_num]; + + j=0; + for (i=0; iblock) + { + GlobalV::ofs_warning << " cpu 2D distribution : " << dim[0] << "*" << dim[1] << std::endl; + GlobalV::ofs_warning << " but, the number of bands-row-block is " << block << std::endl; + ModuleBase::WARNING_QUIT("ORB_control::mat_2d","some processor has no bands-row-blocks."); + } + int col_b_bands = block / dim[1]; + if (coord[1] < block % dim[1]) + { + col_b_bands++; + } + if (block%dim[1]==0) + { + end_id=dim[1]-1; + } + else + { + end_id=block%dim[1]-1; + } + if (coord[1]==end_id) + { + pv->ncol_bands=(col_b_bands-1)*nb+(N_A-(block-1)*nb); + } + else + { + pv->ncol_bands=col_b_bands*nb; + } + pv->nloc_wfc = pv->ncol_bands * LM.row_num; + + std::cout << pv->nloc_wfc << " " << pv->ncol_bands << " " << LM.row_num << std::endl; + + return; +} +#endif + + +#ifdef __MPI +// A : contains total matrix element in processor. +void ORB_control::data_distribution( + MPI_Comm comm_2D, + const std::string &file, + const int &n, + const int &nb, + double *A, + const LocalMatrix &LM) +{ + ModuleBase::TITLE("ORB_control", "data_distribution"); + Parallel_Orbitals* pv = &this->ParaV; + MPI_Comm comm_row; + MPI_Comm comm_col; + MPI_Status status; + + int dim[2]; + int period[2]; + int coord[2]; + MPI_Cart_get(comm_2D,2,dim,period,coord); + + if(pv->testpb) GlobalV::ofs_running << "\n dim = " << dim[0] << " * " << dim[1] << std::endl; + if(pv->testpb) GlobalV::ofs_running << " coord = ( " << coord[0] << " , " << coord[1] << ")." << std::endl; + if(pv->testpb) GlobalV::ofs_running << " n = " << n << std::endl; + + mpi_sub_col(comm_2D,&comm_col); + mpi_sub_row(comm_2D,&comm_row); + + // total number of processors + const int myid = coord[0]*dim[1]+coord[1]; + + // the matrix is n * n + double* ele_val = new double[n]; + double* val = new double[n]; + int* sends = new int[dim[1]]; + int* fpt = new int[dim[1]]; + int* snd = new int[dim[1]]; + int* temp = new int[dim[1]]; + + ModuleBase::GlobalFunc::ZEROS(ele_val, n); + ModuleBase::GlobalFunc::ZEROS(val, n); + ModuleBase::GlobalFunc::ZEROS(sends, dim[1]); + ModuleBase::GlobalFunc::ZEROS(fpt, dim[1]); + ModuleBase::GlobalFunc::ZEROS(snd, dim[1]); + ModuleBase::GlobalFunc::ZEROS(temp, dim[1]); + + // the columes of matrix is divided by 'dim[1]' 'rows of processors'. + // collect all information of each 'rows of processors' + // collection data is saved in 'sends' + snd[coord[1]] = LM.col_num; + MPI_Allgather(&snd[coord[1]],1,MPI_INT,sends,1,MPI_INT,comm_row); + + // fpt : start column index after applied 'mat_2d' reorder algorithms + // to matrix. + fpt[0] = 0; + for (int i=1; iParaV; - // set loc_size - if(GlobalV::GAMMA_ONLY_LOCAL)//xiaohui add 2014-12-21 - { - pv->loc_size=GlobalV::NBANDS/GlobalV::DSIZE; - - // mohan add 2012-03-29 - if(pv->loc_size==0) - { - GlobalV::ofs_warning << " loc_size=0" << " in proc " << GlobalV::MY_RANK+1 << std::endl; - ModuleBase::WARNING_QUIT("ORB_control::set_parameters","NLOCAL < GlobalV::DSIZE"); - } - - if (GlobalV::DRANKloc_size+=1; - if(pv->testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"local size",pv->loc_size); - - // set loc_sizes - delete[] pv->loc_sizes; - pv->loc_sizes = new int[GlobalV::DSIZE]; - ModuleBase::GlobalFunc::ZEROS(pv->loc_sizes, GlobalV::DSIZE); - - pv->lastband_in_proc = 0; - pv->lastband_number = 0; - int count_bands = 0; - for (int i=0; iloc_sizes[i]=GlobalV::NBANDS/GlobalV::DSIZE+1; - } - else - { - pv->loc_sizes[i]=GlobalV::NBANDS/GlobalV::DSIZE; - } - count_bands += pv->loc_sizes[i]; - if (count_bands >= GlobalV::NBANDS) - { - pv->lastband_in_proc = i; - pv->lastband_number = GlobalV::NBANDS - (count_bands - pv->loc_sizes[i]); - break; - } - } - } - else - { - pv->loc_size=GlobalV::NLOCAL/GlobalV::DSIZE; - - // mohan add 2012-03-29 - if(pv->loc_size==0) - { - GlobalV::ofs_warning << " loc_size=0" << " in proc " << GlobalV::MY_RANK+1 << std::endl; - ModuleBase::WARNING_QUIT("ORB_control::set_parameters","NLOCAL < GlobalV::DSIZE"); - } - - if (GlobalV::DRANKloc_size += 1; - } - if(pv->testpb) ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"local size",pv->loc_size); - - // set loc_sizes - delete[] pv->loc_sizes; - pv->loc_sizes = new int[GlobalV::DSIZE]; - ModuleBase::GlobalFunc::ZEROS(pv->loc_sizes, GlobalV::DSIZE); - - pv->lastband_in_proc = 0; - pv->lastband_number = 0; - int count_bands = 0; - for (int i=0; iloc_sizes[i]=GlobalV::NLOCAL/GlobalV::DSIZE+1; - } - else - { - pv->loc_sizes[i]=GlobalV::NLOCAL/GlobalV::DSIZE; - } - count_bands += pv->loc_sizes[i]; - if (count_bands >= GlobalV::NBANDS) - { - pv->lastband_in_proc = i; - pv->lastband_number = GlobalV::NBANDS - (count_bands - pv->loc_sizes[i]); - break; - } - } - }//xiaohui add 2014-12-21 - - if (GlobalV::KS_SOLVER=="hpseps") //LiuXh add 2021-09-06, clear memory, Z_LOC only used in hpseps solver - { - pv->Z_LOC = new double*[GlobalV::NSPIN]; - for(int is=0; isZ_LOC[is] = new double[pv->loc_size * GlobalV::NLOCAL]; - ModuleBase::GlobalFunc::ZEROS(pv->Z_LOC[is], pv->loc_size * GlobalV::NLOCAL); - } - pv->alloc_Z_LOC = true;//xiaohui add 2014-12-22 - } - - if(pv->testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"lastband_in_proc", pv->lastband_in_proc); - if(pv->testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"lastband_number", pv->lastband_number); - - return; -} - - -#ifdef __MPI -// creat the 'comm_2D' stratege. -void ORB_control::mpi_creat_cart(MPI_Comm *comm_2D, int prow, int pcol) -{ - ModuleBase::TITLE("ORB_control","mpi_creat_cart"); - // the matrix is divided as ( dim[0] * dim[1] ) - int dim[2]; - int period[2]={1,1}; - int reorder=0; - dim[0]=prow; - dim[1]=pcol; - - if(this->ParaV.testpb)GlobalV::ofs_running << " dim = " << dim[0] << " * " << dim[1] << std::endl; - - MPI_Cart_create(DIAG_WORLD,2,dim,period,reorder,comm_2D); - return; -} -#endif - -#ifdef __MPI -void ORB_control::mat_2d(MPI_Comm vu, - const int &M_A, - const int &N_A, - const int &nb, - LocalMatrix &LM) -{ - ModuleBase::TITLE("ORB_control", "mat_2d"); - - Parallel_Orbitals* pv = &this->ParaV; - - int dim[2]; - int period[2]; - int coord[2]; - int i,j,k,end_id; - int block; - - // (0) every processor get it's id on the 2D comm - // : ( coord[0], coord[1] ) - MPI_Cart_get(vu,2,dim,period,coord); - - // (1.1) how many blocks at least - // eg. M_A = 6400, nb = 64; - // so block = 10; - block=M_A/nb; - - // (1.2) If data remain, add 1. - if (block*nbtestpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"Total Row Blocks Number",block); - - // mohan add 2010-09-12 - if(dim[0]>block) - { - GlobalV::ofs_warning << " cpu 2D distribution : " << dim[0] << "*" << dim[1] << std::endl; - GlobalV::ofs_warning << " but, the number of row blocks is " << block << std::endl; - ModuleBase::WARNING_QUIT("ORB_control::mat_2d","some processor has no row blocks, try a smaller 'nb2d' parameter."); - } - - // (2.1) row_b : how many blocks for this processor. (at least) - LM.row_b=block/dim[0]; - - // (2.2) row_b : how many blocks in this processor. - // if there are blocks remain, some processors add 1. - if (coord[0]testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"Local Row Block Number",LM.row_b); - - // (3) end_id indicates the last block belong to - // which processor. - if (block%dim[0]==0) - { - end_id=dim[0]-1; - } - else - { - end_id=block%dim[0]-1; - } - - if(pv->testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"Ending Row Block in processor",end_id); - - // (4) row_num : how many rows in this processors : - // the one owns the last block is different. - if (coord[0]==end_id) - { - LM.row_num=(LM.row_b-1)*nb+(M_A-(block-1)*nb); - } - else - { - LM.row_num=LM.row_b*nb; - } - - if(pv->testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"Local rows (including nb)",LM.row_num); - - // (5) row_set, it's a global index : - // save explicitly : every row in this processor - // belongs to which row in the global matrix. - delete[] LM.row_set; - LM.row_set= new int[LM.row_num]; - j=0; - for (i=0; iblock) - { - GlobalV::ofs_warning << " cpu 2D distribution : " << dim[0] << "*" << dim[1] << std::endl; - GlobalV::ofs_warning << " but, the number of column blocks is " << block << std::endl; - ModuleBase::WARNING_QUIT("ORB_control::mat_2d","some processor has no column blocks."); - } - - LM.col_b=block/dim[1]; - if (coord[1]testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"Local Row Block Number",LM.col_b); - - if (block%dim[1]==0) - { - end_id=dim[1]-1; - } - else - { - end_id=block%dim[1]-1; - } - - if(pv->testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"Ending Row Block in processor",end_id); - - if (coord[1]==end_id) - { - LM.col_num=(LM.col_b-1)*nb+(M_A-(block-1)*nb); - } - else - { - LM.col_num=LM.col_b*nb; - } - - if(pv->testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"Local columns (including nb)",LM.row_num); - - delete[] LM.col_set; - LM.col_set = new int[LM.col_num]; - - j=0; - for (i=0; iblock) - { - GlobalV::ofs_warning << " cpu 2D distribution : " << dim[0] << "*" << dim[1] << std::endl; - GlobalV::ofs_warning << " but, the number of bands-row-block is " << block << std::endl; - ModuleBase::WARNING_QUIT("ORB_control::mat_2d","some processor has no bands-row-blocks."); - } - int col_b_bands = block / dim[1]; - if (coord[1] < block % dim[1]) - { - col_b_bands++; - } - if (block%dim[1]==0) - { - end_id=dim[1]-1; - } - else - { - end_id=block%dim[1]-1; - } - if (coord[1]==end_id) - { - pv->ncol_bands=(col_b_bands-1)*nb+(N_A-(block-1)*nb); - } - else - { - pv->ncol_bands=col_b_bands*nb; - } - pv->nloc_wfc = pv->ncol_bands * LM.row_num; - - std::cout << pv->nloc_wfc << " " << pv->ncol_bands << " " << LM.row_num << std::endl; - - return; -} -#endif - - -#ifdef __MPI -// A : contains total matrix element in processor. -void ORB_control::data_distribution( - MPI_Comm comm_2D, - const std::string &file, - const int &n, - const int &nb, - double *A, - const LocalMatrix &LM) -{ - ModuleBase::TITLE("ORB_control", "data_distribution"); - Parallel_Orbitals* pv = &this->ParaV; - MPI_Comm comm_row; - MPI_Comm comm_col; - MPI_Status status; - - int dim[2]; - int period[2]; - int coord[2]; - MPI_Cart_get(comm_2D,2,dim,period,coord); - - if(pv->testpb) GlobalV::ofs_running << "\n dim = " << dim[0] << " * " << dim[1] << std::endl; - if(pv->testpb) GlobalV::ofs_running << " coord = ( " << coord[0] << " , " << coord[1] << ")." << std::endl; - if(pv->testpb) GlobalV::ofs_running << " n = " << n << std::endl; - - mpi_sub_col(comm_2D,&comm_col); - mpi_sub_row(comm_2D,&comm_row); - - // total number of processors - const int myid = coord[0]*dim[1]+coord[1]; - - // the matrix is n * n - double* ele_val = new double[n]; - double* val = new double[n]; - int* sends = new int[dim[1]]; - int* fpt = new int[dim[1]]; - int* snd = new int[dim[1]]; - int* temp = new int[dim[1]]; - - ModuleBase::GlobalFunc::ZEROS(ele_val, n); - ModuleBase::GlobalFunc::ZEROS(val, n); - ModuleBase::GlobalFunc::ZEROS(sends, dim[1]); - ModuleBase::GlobalFunc::ZEROS(fpt, dim[1]); - ModuleBase::GlobalFunc::ZEROS(snd, dim[1]); - ModuleBase::GlobalFunc::ZEROS(temp, dim[1]); - - // the columes of matrix is divided by 'dim[1]' 'rows of processors'. - // collect all information of each 'rows of processors' - // collection data is saved in 'sends' - snd[coord[1]] = LM.col_num; - MPI_Allgather(&snd[coord[1]],1,MPI_INT,sends,1,MPI_INT,comm_row); - - // fpt : start column index after applied 'mat_2d' reorder algorithms - // to matrix. - fpt[0] = 0; - for (int i=1; i **cc,std::complex *Z, const int &ik) -{ - ModuleBase::TITLE("Pdiag_Double","gath_eig_complex"); - time_t time_start = time(NULL); - //GlobalV::ofs_running << " Start gath_eig_complex Time : " << ctime(&time_start); - const Parallel_Orbitals* pv = this->ParaV; - - int i,j,k; - int nprocs,myid; - MPI_Status status; - MPI_Comm_size(comm,&nprocs); - MPI_Comm_rank(comm,&myid); - - std::complex **ctot; - - // mohan add 2010-07-03 - // the occupied bands are useless - // for calculating charge density. - if(GlobalV::DRANK> pv->lastband_in_proc) - { - delete[] Z; - } - - // first we need to collect all - // the occupied bands. - // GlobalV::NBANDS * GlobalV::NLOCAL - if(GlobalV::DRANK==0) - { - ctot = new std::complex*[GlobalV::NBANDS]; - for (int i=0; i[GlobalV::NLOCAL]; - ModuleBase::GlobalFunc::ZEROS(ctot[i], GlobalV::NLOCAL); - } - ModuleBase::Memory::record("Pdiag_Double","ctot",GlobalV::NBANDS*GlobalV::NLOCAL,"cdouble"); - } - - k=0; - if (myid==0) - { - // mohan add nbnd0 2010-07-02 - int nbnd0 = -1; - if (GlobalV::NBANDS < pv->loc_sizes[0]) - { - // means all bands in this processor - // is needed ( is occupied) - nbnd0 = GlobalV::NBANDS; - } - else - { - // means this processor only save - // part of GlobalV::NBANDS. - nbnd0 = pv->loc_sizes[0]; - } - if(pv->testpb)GlobalV::ofs_running << " nbnd in processor 0 is " << nbnd0 << std::endl; - - for (i=0; iloc_sizes[0]+i]; - } - k++; - } - // Z is useless in processor 0 now. - delete[] Z; - } - MPI_Barrier(comm); - - for (i=1; i<= pv->lastband_in_proc; i++) - { - // mohan fix bug 2010-07-02 - // rows indicates the data structure of Z. - // mpi_times indicates the data distribution - // time, each time send a band. - int rows = pv->loc_sizes[i]; - int mpi_times; - if (i==pv->lastband_in_proc) - { - mpi_times = pv->lastband_number; - } - else - { - mpi_times = pv->loc_sizes[i]; - } - if(pv->testpb)GlobalV::ofs_running << " nbnd in processor " << i << " is " << mpi_times << std::endl; - if (myid==i) - { - for (j=0; j *send = new std::complex[n]; - int count = 0; - - for (int m=0; m *ctmp = new std::complex[GlobalV::NLOCAL]; - ModuleBase::GlobalFunc::ZEROS(ctmp, GlobalV::NLOCAL); - int tag = j; - - // Processor 0 receive the data from other processors. - MPI_Recv(ctmp,n,mpicomplex,i,tag,comm,&status); - - for (int m=0; mout_lowf) - { -// std::cout << " write the wave functions" << std::endl; - WF_Local::write_lowf_complex( ss.str(), ctot, ik );//mohan add 2010-09-09 - } - - // mohan add 2010-09-10 - // distribution of local wave functions - // to each processor. - WF_Local::distri_lowf_complex( ctot, cc); - - // clean staff. - if(GlobalV::DRANK==0) - { - for (int i=0; i **c,std::complex *Z) -{ - ModuleBase::TITLE("Pdiag_Double","gath_full_eig_complex"); - - time_t time_start = time(NULL); - //GlobalV::ofs_running << " Start gath_full_eig_complex Time : " << ctime(&time_start); - - int i,j,k,incx=1; - int *loc_sizes,loc_size,nprocs,myid; - MPI_Status status; - MPI_Comm_size(comm,&nprocs); - MPI_Comm_rank(comm,&myid); - loc_sizes =(int*)malloc(sizeof(int)*nprocs); - loc_size=n/nprocs; - for (i=0; iParaV; - - int i, j, k; - int nprocs,myid; - MPI_Status status; - MPI_Comm_size(comm,&nprocs); - MPI_Comm_rank(comm,&myid); - - double **ctot; - - // mohan add 2010-07-03 - // the occupied bands are useless - // for calculating charge density. - if(GlobalV::DRANK > pv->lastband_in_proc) - { - delete[] Z; - } - - // first we need to collect all - // the occupied bands. - // GlobalV::NBANDS * GlobalV::NLOCAL - if(myid==0) - { - ctot = new double*[GlobalV::NBANDS]; - for (int i=0; iloc_sizes[0]) - { - // means all bands in this processor - // is needed ( is occupied) - nbnd0 = GlobalV::NBANDS; - } - else - { - // means this processor only save - // part of GlobalV::NBANDS. - nbnd0 = pv->loc_sizes[0]; - } - if(pv->testpb) GlobalV::ofs_running << " nbnd in processor 0 is " << nbnd0 << std::endl; - -//printf("from 0 to %d\n",nbnd0-1); - for (i=0; iloc_sizes[0]+i]; - } - k++; - } - // Z is useless in processor 0 now. - delete[] Z; - } - MPI_Barrier(comm); - - - for (i=1; i<= pv->lastband_in_proc; i++) - { - // mohan fix bug 2010-07-02 - // rows indicates the data structure of Z. - // mpi_times indicates the data distribution - // time, each time send a band. - int rows = pv->loc_sizes[i]; - int mpi_times; - if (i==pv->lastband_in_proc) - { - mpi_times = pv->lastband_number; - } - else - { - mpi_times = pv->loc_sizes[i]; - } - if(pv->testpb)GlobalV::ofs_running << " nbnd in processor " << i << " is " << mpi_times << std::endl; - if (myid==i) - { - for (j=0; j1); -} -*/ -MPI_Barrier(comm); - - // mohan add 2010-09-10 - // output the wave function if required. - // this is a bad position to output wave functions. - // but it works! - if(this->out_lowf) - { - // read is in ../src_algorithms/wf_local.cpp - std::stringstream ss; - ss << GlobalV::global_out_dir << "LOWF_GAMMA_S" << GlobalV::CURRENT_SPIN+1 << ".dat"; - // mohan add 2012-04-03, because we need the occupations for the - // first iteration. - Occupy::calculate_weights(); - WF_Local::write_lowf( ss.str(), ctot );//mohan add 2010-09-09 - } - - // mohan add 2010-09-10 - // distribution of local wave functions - // to each processor. - // only used for GlobalV::GAMMA_ONLY_LOCAL - //WF_Local::distri_lowf( ctot, wfc); - - - // clean staff. - if(myid==0) - { - for (int i=0; i **cc,std::complex *Z, const int &ik) +{ + ModuleBase::TITLE("Pdiag_Double","gath_eig_complex"); + time_t time_start = time(NULL); + //GlobalV::ofs_running << " Start gath_eig_complex Time : " << ctime(&time_start); + const Parallel_Orbitals* pv = this->ParaV; + + int i,j,k; + int nprocs,myid; + MPI_Status status; + MPI_Comm_size(comm,&nprocs); + MPI_Comm_rank(comm,&myid); + + std::complex **ctot; + + // mohan add 2010-07-03 + // the occupied bands are useless + // for calculating charge density. + if(GlobalV::DRANK> pv->lastband_in_proc) + { + delete[] Z; + } + + // first we need to collect all + // the occupied bands. + // GlobalV::NBANDS * GlobalV::NLOCAL + if(GlobalV::DRANK==0) + { + ctot = new std::complex*[GlobalV::NBANDS]; + for (int i=0; i[GlobalV::NLOCAL]; + ModuleBase::GlobalFunc::ZEROS(ctot[i], GlobalV::NLOCAL); + } + ModuleBase::Memory::record("Pdiag_Double","ctot",GlobalV::NBANDS*GlobalV::NLOCAL,"cdouble"); + } + + k=0; + if (myid==0) + { + // mohan add nbnd0 2010-07-02 + int nbnd0 = -1; + if (GlobalV::NBANDS < pv->loc_sizes[0]) + { + // means all bands in this processor + // is needed ( is occupied) + nbnd0 = GlobalV::NBANDS; + } + else + { + // means this processor only save + // part of GlobalV::NBANDS. + nbnd0 = pv->loc_sizes[0]; + } + if(pv->testpb)GlobalV::ofs_running << " nbnd in processor 0 is " << nbnd0 << std::endl; + + for (i=0; iloc_sizes[0]+i]; + } + k++; + } + // Z is useless in processor 0 now. + delete[] Z; + } + MPI_Barrier(comm); + + for (i=1; i<= pv->lastband_in_proc; i++) + { + // mohan fix bug 2010-07-02 + // rows indicates the data structure of Z. + // mpi_times indicates the data distribution + // time, each time send a band. + int rows = pv->loc_sizes[i]; + int mpi_times; + if (i==pv->lastband_in_proc) + { + mpi_times = pv->lastband_number; + } + else + { + mpi_times = pv->loc_sizes[i]; + } + if(pv->testpb)GlobalV::ofs_running << " nbnd in processor " << i << " is " << mpi_times << std::endl; + if (myid==i) + { + for (j=0; j *send = new std::complex[n]; + int count = 0; + + for (int m=0; m *ctmp = new std::complex[GlobalV::NLOCAL]; + ModuleBase::GlobalFunc::ZEROS(ctmp, GlobalV::NLOCAL); + int tag = j; + + // Processor 0 receive the data from other processors. + MPI_Recv(ctmp,n,mpicomplex,i,tag,comm,&status); + + for (int m=0; mout_lowf) + { +// std::cout << " write the wave functions" << std::endl; + WF_Local::write_lowf_complex( ss.str(), ctot, ik );//mohan add 2010-09-09 + } + + // mohan add 2010-09-10 + // distribution of local wave functions + // to each processor. + WF_Local::distri_lowf_complex( ctot, cc); + + // clean staff. + if(GlobalV::DRANK==0) + { + for (int i=0; i **c,std::complex *Z) +{ + ModuleBase::TITLE("Pdiag_Double","gath_full_eig_complex"); + + time_t time_start = time(NULL); + //GlobalV::ofs_running << " Start gath_full_eig_complex Time : " << ctime(&time_start); + + int i,j,k,incx=1; + int *loc_sizes,loc_size,nprocs,myid; + MPI_Status status; + MPI_Comm_size(comm,&nprocs); + MPI_Comm_rank(comm,&myid); + loc_sizes =(int*)malloc(sizeof(int)*nprocs); + loc_size=n/nprocs; + for (i=0; iParaV; + + int i, j, k; + int nprocs,myid; + MPI_Status status; + MPI_Comm_size(comm,&nprocs); + MPI_Comm_rank(comm,&myid); + + double **ctot; + + // mohan add 2010-07-03 + // the occupied bands are useless + // for calculating charge density. + if(GlobalV::DRANK > pv->lastband_in_proc) + { + delete[] Z; + } + + // first we need to collect all + // the occupied bands. + // GlobalV::NBANDS * GlobalV::NLOCAL + if(myid==0) + { + ctot = new double*[GlobalV::NBANDS]; + for (int i=0; iloc_sizes[0]) + { + // means all bands in this processor + // is needed ( is occupied) + nbnd0 = GlobalV::NBANDS; + } + else + { + // means this processor only save + // part of GlobalV::NBANDS. + nbnd0 = pv->loc_sizes[0]; + } + if(pv->testpb) GlobalV::ofs_running << " nbnd in processor 0 is " << nbnd0 << std::endl; + +//printf("from 0 to %d\n",nbnd0-1); + for (i=0; iloc_sizes[0]+i]; + } + k++; + } + // Z is useless in processor 0 now. + delete[] Z; + } + MPI_Barrier(comm); + + + for (i=1; i<= pv->lastband_in_proc; i++) + { + // mohan fix bug 2010-07-02 + // rows indicates the data structure of Z. + // mpi_times indicates the data distribution + // time, each time send a band. + int rows = pv->loc_sizes[i]; + int mpi_times; + if (i==pv->lastband_in_proc) + { + mpi_times = pv->lastband_number; + } + else + { + mpi_times = pv->loc_sizes[i]; + } + if(pv->testpb)GlobalV::ofs_running << " nbnd in processor " << i << " is " << mpi_times << std::endl; + if (myid==i) + { + for (j=0; j1); +} +*/ +MPI_Barrier(comm); + + // mohan add 2010-09-10 + // output the wave function if required. + // this is a bad position to output wave functions. + // but it works! + if(this->out_lowf) + { + // read is in ../src_algorithms/wf_local.cpp + std::stringstream ss; + ss << GlobalV::global_out_dir << "LOWF_GAMMA_S" << GlobalV::CURRENT_SPIN+1 << ".dat"; + // mohan add 2012-04-03, because we need the occupations for the + // first iteration. + Occupy::calculate_weights(); + WF_Local::write_lowf( ss.str(), ctot );//mohan add 2010-09-09 + } + + // mohan add 2010-09-10 + // distribution of local wave functions + // to each processor. + // only used for GlobalV::GAMMA_ONLY_LOCAL + //WF_Local::distri_lowf( ctot, wfc); + + + // clean staff. + if(myid==0) + { + for (int i=0; i Date: Sun, 27 Feb 2022 18:09:04 +0800 Subject: [PATCH 44/52] remove cal_nnr by merge it with Record_adf::for_2d --- source/src_lcao/FORCE_STRESS.cpp | 6 +- source/src_lcao/FORCE_STRESS.h | 7 +- source/src_lcao/FORCE_k.cpp | 38 +++--- source/src_lcao/FORCE_k.h | 6 +- source/src_lcao/LCAO_nnr.cpp | 148 ------------------------ source/src_lcao/LOOP_elec.cpp | 15 +-- source/src_lcao/LOOP_elec.h | 3 +- source/src_lcao/LOOP_ions.cpp | 20 ++-- source/src_lcao/LOOP_ions.h | 2 +- source/src_lcao/record_adj.cpp | 90 ++++++++------ source/src_lcao/record_adj.h | 2 +- source/src_lcao/run_md_lcao.cpp | 9 +- source/src_parallel/parallel_orbitals.h | 14 +-- 13 files changed, 114 insertions(+), 246 deletions(-) diff --git a/source/src_lcao/FORCE_STRESS.cpp b/source/src_lcao/FORCE_STRESS.cpp index 6b432d481bd..e7be8ea6dd1 100644 --- a/source/src_lcao/FORCE_STRESS.cpp +++ b/source/src_lcao/FORCE_STRESS.cpp @@ -14,8 +14,9 @@ double Force_Stress_LCAO::force_invalid_threshold_ev = 0.00; double Force_Stress_LCAO::output_acc = 1.0e-8; -Force_Stress_LCAO::Force_Stress_LCAO (){} -Force_Stress_LCAO::~Force_Stress_LCAO (){} +Force_Stress_LCAO::Force_Stress_LCAO(Record_adj& ra) : + RA(&ra){} +Force_Stress_LCAO::~Force_Stress_LCAO() {} #include "../src_pw/efield.h" void Force_Stress_LCAO::getForceStress( @@ -749,6 +750,7 @@ void Force_Stress_LCAO::calForceStressIntegralPart( flk.ftable_k( isforce, isstress, + *this->RA, lowf.wfc_k, loc, foverlap, diff --git a/source/src_lcao/FORCE_STRESS.h b/source/src_lcao/FORCE_STRESS.h index 1c6190d60ad..d4bc36f4218 100644 --- a/source/src_lcao/FORCE_STRESS.h +++ b/source/src_lcao/FORCE_STRESS.h @@ -22,11 +22,12 @@ class Force_Stress_LCAO public : - Force_Stress_LCAO (); + Force_Stress_LCAO (Record_adj &ra); ~Force_Stress_LCAO (); - private: - +private: + + Record_adj* RA; Force_LCAO_k flk; // Force_LCAO_gamma flg; Stress_Func sc_pw; diff --git a/source/src_lcao/FORCE_k.cpp b/source/src_lcao/FORCE_k.cpp index a62ce97802b..3c41072b157 100644 --- a/source/src_lcao/FORCE_k.cpp +++ b/source/src_lcao/FORCE_k.cpp @@ -22,6 +22,7 @@ Force_LCAO_k::~Force_LCAO_k () void Force_LCAO_k::ftable_k ( const bool isforce, const bool isstress, + Record_adj &ra, std::vector& wfc_k, Local_Orbital_Charge &loc, ModuleBase::matrix& foverlap, @@ -49,7 +50,7 @@ void Force_LCAO_k::ftable_k ( // calculate the energy density matrix // and the force related to overlap matrix and energy density matrix. - this->cal_foverlap_k(isforce, isstress, wfc_k, loc, foverlap, soverlap); + this->cal_foverlap_k(isforce, isstress, ra, wfc_k, loc, foverlap, soverlap); // calculate the density matrix double** dm2d = new double*[GlobalV::NSPIN]; @@ -60,11 +61,9 @@ void Force_LCAO_k::ftable_k ( } ModuleBase::Memory::record ("Force_LCAO_k", "dm2d", GlobalV::NSPIN*pv->nnr, "double"); - Record_adj RA; - RA.for_2d(*pv); - loc.cal_dm_R(loc.dm_k, RA, dm2d); + loc.cal_dm_R(loc.dm_k, ra, dm2d); - this->cal_ftvnl_dphi_k(dm2d, isforce, isstress, ftvnl_dphi, stvnl_dphi); + this->cal_ftvnl_dphi_k(dm2d, isforce, isstress, ra, ftvnl_dphi, stvnl_dphi); // --------------------------------------- @@ -237,6 +236,7 @@ void Force_LCAO_k::finish_k(void) void Force_LCAO_k::cal_foverlap_k( const bool isforce, const bool isstress, + Record_adj &ra, std::vector& wfc_k, Local_Orbital_Charge &loc, ModuleBase::matrix& foverlap, @@ -255,9 +255,6 @@ void Force_LCAO_k::cal_foverlap_k( edm2d[is] = new double[pv->nnr]; ModuleBase::GlobalFunc::ZEROS(edm2d[is], pv->nnr); } - - Record_adj RA; - RA.for_2d(*pv); //-------------------------------------------- // calculate the energy density matrix here. @@ -279,7 +276,7 @@ void Force_LCAO_k::cal_foverlap_k( wfc_k, edm_k); loc.cal_dm_R(edm_k, - RA, edm2d); + ra, edm2d); ModuleBase::timer::tick("Force_LCAO_k", "cal_edm_2d"); //-------------------------------------------- @@ -298,10 +295,10 @@ void Force_LCAO_k::cal_foverlap_k( for(int I1=0; I1na; ++I1) { const int start1 = GlobalC::ucell.itiaiw2iwt(T1,I1,0); - for (int cb = 0; cb < RA.na_each[iat]; ++cb) + for (int cb = 0; cb < ra.na_each[iat]; ++cb) { - const int T2 = RA.info[iat][cb][3]; - const int I2 = RA.info[iat][cb][4]; + const int T2 = ra.info[iat][cb][3]; + const int I2 = ra.info[iat][cb][4]; const int start2 = GlobalC::ucell.itiaiw2iwt(T2, I2, 0); Atom* atom2 = &GlobalC::ucell.atoms[T2]; @@ -374,7 +371,6 @@ void Force_LCAO_k::cal_foverlap_k( } delete[] edm2d; - RA.delete_grid();//xiaohui add 2015-02-04 ModuleBase::timer::tick("Force_LCAO_k","cal_foverlap_k"); return; } @@ -382,8 +378,9 @@ void Force_LCAO_k::cal_foverlap_k( void Force_LCAO_k::cal_ftvnl_dphi_k( double** dm2d, const bool isforce, - const bool isstress, - ModuleBase::matrix& ftvnl_dphi, + const bool isstress, + Record_adj &ra, + ModuleBase::matrix& ftvnl_dphi, ModuleBase::matrix& stvnl_dphi) { ModuleBase::TITLE("Force_LCAO_k","cal_ftvnl_dphi"); @@ -393,8 +390,6 @@ void Force_LCAO_k::cal_ftvnl_dphi_k( // get the adjacent atom's information. // GlobalV::ofs_running << " calculate the ftvnl_dphi_k force" << std::endl; - Record_adj RA; - RA.for_2d(*this->UHM->LM->ParaV); int irr = 0; for(int T1=0; T1UHM->LM->ParaV); + RA.for_2d(*this->UHM->LM->ParaV, GlobalV::GAMMA_ONLY_LOCAL); double *test; test = new double[GlobalV::NLOCAL * GlobalV::NLOCAL]; diff --git a/source/src_lcao/FORCE_k.h b/source/src_lcao/FORCE_k.h index 04297f8b46b..57a902d6881 100644 --- a/source/src_lcao/FORCE_k.h +++ b/source/src_lcao/FORCE_k.h @@ -27,6 +27,7 @@ class Force_LCAO_k : public Force_LCAO_gamma void ftable_k ( const bool isforce, const bool isstress, + Record_adj &ra, std::vector& wfc_k, Local_Orbital_Charge &loc, ModuleBase::matrix& foverlap, @@ -50,10 +51,11 @@ class Force_LCAO_k : public Force_LCAO_gamma void finish_k(void); // calculate the force due to < dphi | beta > < beta | phi > - void cal_ftvnl_dphi_k(double** dm2d, const bool isforce, const bool isstress, ModuleBase::matrix& ftvnl_dphi, ModuleBase::matrix& stvnl_dphi); + void cal_ftvnl_dphi_k(double** dm2d, const bool isforce, const bool isstress, Record_adj& ra, + ModuleBase::matrix& ftvnl_dphi, ModuleBase::matrix& stvnl_dphi); // calculate the overlap force - void cal_foverlap_k(const bool isforce, const bool isstress, std::vector& wfc_k, + void cal_foverlap_k(const bool isforce, const bool isstress, Record_adj &ra, std::vector& wfc_k, Local_Orbital_Charge& loc, ModuleBase::matrix& foverlap, ModuleBase::matrix& soverlap); // calculate the force due to < phi | Vlocal | dphi > diff --git a/source/src_lcao/LCAO_nnr.cpp b/source/src_lcao/LCAO_nnr.cpp index 222bcba181c..6561cb7d4e7 100644 --- a/source/src_lcao/LCAO_nnr.cpp +++ b/source/src_lcao/LCAO_nnr.cpp @@ -5,154 +5,6 @@ #ifdef __DEEPKS #include "../module_deepks/LCAO_deepks.h" #endif -//---------------------------- -// define a global class obj. -//---------------------------- - -// be called in LOOP_ions.cpp -void Parallel_Orbitals::cal_nnr() -{ - ModuleBase::TITLE("LCAO_nnr","cal_nnr"); - - delete[] nlocdim; - delete[] nlocstart; - nlocdim = new int[GlobalC::ucell.nat]; - nlocstart = new int[GlobalC::ucell.nat]; - ModuleBase::GlobalFunc::ZEROS(nlocdim, GlobalC::ucell.nat); - ModuleBase::GlobalFunc::ZEROS(nlocstart, GlobalC::ucell.nat); - - this->nnr = 0; - int start = 0; - //int ind1 = 0; - int iat = 0; - - // (1) find the adjacent atoms of atom[T1,I1]; - ModuleBase::Vector3 tau1; - ModuleBase::Vector3 tau2; - ModuleBase::Vector3 dtau; - ModuleBase::Vector3 tau0; - ModuleBase::Vector3 dtau1; - ModuleBase::Vector3 dtau2; - - for (int T1 = 0; T1 < GlobalC::ucell.ntype; T1++) - { - for (int I1 = 0; I1 < GlobalC::ucell.atoms[T1].na; I1++) - { - tau1 = GlobalC::ucell.atoms[T1].tau[I1]; - //GlobalC::GridD.Find_atom( tau1 ); - GlobalC::GridD.Find_atom(GlobalC::ucell, tau1 ,T1, I1); - const int start1 = GlobalC::ucell.itiaiw2iwt(T1, I1, 0); - this->nlocstart[iat] = nnr; - int nw1 = GlobalC::ucell.atoms[T1].nw * GlobalV::NPOL; - - // (2) search among all adjacent atoms. - for (int ad = 0; ad < GlobalC::GridD.getAdjacentNum()+1; ad++) - { - const int T2 = GlobalC::GridD.getType(ad); - const int I2 = GlobalC::GridD.getNatom(ad); - //const int iat2 = GlobalC::ucell.itia2iat(T2, I2); - const int start2 = GlobalC::ucell.itiaiw2iwt(T2, I2, 0); - int nw2 = GlobalC::ucell.atoms[T2].nw * GlobalV::NPOL; - - tau2 = GlobalC::GridD.getAdjacentTau(ad); - - dtau = tau2 - tau1; - double distance = dtau.norm() * GlobalC::ucell.lat0; - double rcut = GlobalC::ORB.Phi[T1].getRcut() + GlobalC::ORB.Phi[T2].getRcut(); - - if(distance < rcut) - { - //-------------------------------------------------- - // calculate how many matrix elements are in - // this processor. - for(int ii=0; iitrace_loc_row[iw1_all]; - if(mu<0)continue; - - for(int jj=0; jjtrace_loc_col[iw2_all]; - if(nu<0)continue; - - // orbital numbers for this atom (iat), - // seperated by atoms in different cells. - this->nlocdim[iat]++; - - ++nnr; - }// end jj - }// end ii - }//end distance - // there is another possibility that i and j are adjacent atoms. - // which is that are adjacents while are also - // adjacents, these considerations are only considered in k-point - // algorithm, - // mohan fix bug 2012-07-03 - else if(distance >= rcut) - { - for (int ad0 = 0; ad0 < GlobalC::GridD.getAdjacentNum()+1; ++ad0) - { - const int T0 = GlobalC::GridD.getType(ad0); - const int I0 = GlobalC::GridD.getNatom(ad0); - //const int iat0 = GlobalC::ucell.itia2iat(T0, I0); - //const int start0 = GlobalC::ucell.itiaiw2iwt(T0, I0, 0); - - tau0 = GlobalC::GridD.getAdjacentTau(ad0); - dtau1 = tau0 - tau1; - double distance1 = dtau1.norm() * GlobalC::ucell.lat0; - double rcut1 = GlobalC::ORB.Phi[T1].getRcut() + GlobalC::ucell.infoNL.Beta[T0].get_rcut_max(); - - dtau2 = tau0 - tau2; - double distance2 = dtau2.norm() * GlobalC::ucell.lat0; - double rcut2 = GlobalC::ORB.Phi[T2].getRcut() + GlobalC::ucell.infoNL.Beta[T0].get_rcut_max(); - - if( distance1 < rcut1 && distance2 < rcut2 ) - { - for(int ii=0; iitrace_loc_row[iw1_all]; - if(mu<0)continue; - - for(int jj=0; jjtrace_loc_col[iw2_all]; - if(nu<0)continue; - - // orbital numbers for this atom (iat), - // seperated by atoms in different cells. - this->nlocdim[iat]++; - - ++nnr; - } - } - break; - } // dis1, dis2 - }//ad0 - } - }// end ad - - //start position of atom[T1,I1] - start += nw1; - ++iat; - }// end I1 - } // end T1 - - //xiaohui add 'GlobalV::OUT_LEVEL' line, 2015-09-16 - if(GlobalV::OUT_LEVEL != "m") ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"nnr",nnr); -// for(int iat=0; iatUHM->LM->ParaV; + ra.for_2d(*pv, GlobalV::GAMMA_ONLY_LOCAL); if(!GlobalV::GAMMA_ONLY_LOCAL) { - // For each atom, calculate the adjacent atoms in different cells - // and allocate the space for H(R) and S(R). - Parallel_Orbitals* pv = this->UHM->LM->ParaV; - pv->cal_nnr(); this->UHM->LM->allocate_HS_R(pv->nnr); #ifdef __DEEPKS GlobalC::ld.allocate_V_deltaR(pv->nnr); diff --git a/source/src_lcao/LOOP_elec.h b/source/src_lcao/LOOP_elec.h index 2c4ef229808..dbf1538f3bc 100644 --- a/source/src_lcao/LOOP_elec.h +++ b/source/src_lcao/LOOP_elec.h @@ -18,6 +18,7 @@ class LOOP_elec // mohan add 2021-02-09 void solve_elec_stru(const int& istep, + Record_adj &ra, Local_Orbital_Charge& loc, Local_Orbital_wfc& low, LCAO_Hamilt& uhm_in); @@ -25,7 +26,7 @@ class LOOP_elec private: // set matrix and grid integral - void set_matrix_grid(void); + void set_matrix_grid(Record_adj &ra); void before_solver(const int& istep, Local_Orbital_Charge& loc, diff --git a/source/src_lcao/LOOP_ions.cpp b/source/src_lcao/LOOP_ions.cpp index 86d00928737..b38c54f52dc 100644 --- a/source/src_lcao/LOOP_ions.cpp +++ b/source/src_lcao/LOOP_ions.cpp @@ -149,10 +149,10 @@ void LOOP_ions::opt_ions() GlobalC::en.evdw = vdwd3.get_energy(); } - + Record_adj RA; // solve electronic structures in terms of LCAO // mohan add 2021-02-09 - LOE.solve_elec_stru(this->istep, this->LOC, this->LOWF, this->UHM); + LOE.solve_elec_stru(this->istep, RA, this->LOC, this->LOWF, this->UHM); time_t eend = time(NULL); @@ -234,10 +234,10 @@ void LOOP_ions::opt_ions() time_t fstart = time(NULL); if (GlobalV::CALCULATION=="scf" || GlobalV::CALCULATION=="relax" || GlobalV::CALCULATION=="cell-relax") { - stop = this->force_stress(istep, force_step, stress_step); + stop = this->force_stress(istep, force_step, stress_step, RA); } time_t fend = time(NULL); - + RA.delete_grid(); // PLEASE move the details of CE to other places // mohan add 2021-03-25 //xiaohui add 2014-07-07, for second-order extrapolation @@ -299,7 +299,8 @@ void LOOP_ions::opt_ions() bool LOOP_ions::force_stress( const int &istep, int &force_step, - int &stress_step) + int& stress_step, + Record_adj &ra) { ModuleBase::TITLE("LOOP_ions","force_stress"); @@ -313,7 +314,7 @@ bool LOOP_ions::force_stress( ModuleBase::matrix fcs; // set stress matrix ModuleBase::matrix scs; - Force_Stress_LCAO FSL; + Force_Stress_LCAO FSL(ra); FSL.getForceStress(GlobalV::FORCE, GlobalV::STRESS, GlobalV::TEST_FORCE, GlobalV::TEST_STRESS, this->LOC, this->LOWF, this->UHM, fcs, scs); @@ -535,12 +536,13 @@ void LOOP_ions::final_scf(void) GlobalC::pw.nbxx, GlobalC::pw.nbzp_start, GlobalC::pw.nbzp); // (2) If k point is used here, allocate HlocR after atom_arrange. - if(!GlobalV::GAMMA_ONLY_LOCAL) + Parallel_Orbitals* pv = this->UHM.LM->ParaV; + Record_adj RA; + RA.for_2d(*pv, GlobalV::GAMMA_ONLY_LOCAL); + if (!GlobalV::GAMMA_ONLY_LOCAL) { // For each atom, calculate the adjacent atoms in different cells // and allocate the space for H(R) and S(R). - Parallel_Orbitals* pv = this->UHM.LM->ParaV; - pv->cal_nnr(); this->UHM.LM->allocate_HS_R(pv->nnr); #ifdef __DEEPKS GlobalC::ld.allocate_V_deltaR(pv->nnr); diff --git a/source/src_lcao/LOOP_ions.h b/source/src_lcao/LOOP_ions.h index 3e548b7ee47..d30b4c5d0af 100644 --- a/source/src_lcao/LOOP_ions.h +++ b/source/src_lcao/LOOP_ions.h @@ -46,7 +46,7 @@ class LOOP_ions // the renew of structure factors, etc. should be ran in other places // the 'IMM' and 'LCM' objects should be passed to force_stress() via parameters list // mohan note 2021-03-23 - bool force_stress(const int &istep, int &force_step, int &stress_step); + bool force_stress(const int &istep, int &force_step, int &stress_step, Record_adj &ra); int istep; diff --git a/source/src_lcao/record_adj.cpp b/source/src_lcao/record_adj.cpp index 966383793d9..ff6aecb1765 100644 --- a/source/src_lcao/record_adj.cpp +++ b/source/src_lcao/record_adj.cpp @@ -26,14 +26,26 @@ void Record_adj::delete_grid(void) //-------------------------------------------- // This will record the orbitals according to // HPSEPS's 2D block division. +// If multi-k, calculate nnr at the same time. +// be called only once in an ion-step. //-------------------------------------------- -void Record_adj::for_2d(const Parallel_Orbitals &pv) +void Record_adj::for_2d(Parallel_Orbitals &pv, bool gamma_only) { ModuleBase::TITLE("Record_adj","for_2d"); ModuleBase::timer::tick("Record_adj","for_2d"); - assert(GlobalC::ucell.nat>0); - + assert(GlobalC::ucell.nat > 0); + if (!gamma_only) + { + delete[] pv.nlocdim; + delete[] pv.nlocstart; + pv.nlocdim = new int[GlobalC::ucell.nat]; + pv.nlocstart = new int[GlobalC::ucell.nat]; + ModuleBase::GlobalFunc::ZEROS(pv.nlocdim, GlobalC::ucell.nat); + ModuleBase::GlobalFunc::ZEROS(pv.nlocstart, GlobalC::ucell.nat); + pv.nnr = 0; + } + // (1) find the adjacent atoms of atom[T1,I1]; ModuleBase::Vector3 tau1, tau2, dtau; ModuleBase::Vector3 dtau1, dtau2, tau0; @@ -44,8 +56,6 @@ void Record_adj::for_2d(const Parallel_Orbitals &pv) this->na_each = new int[na_proc]; ModuleBase::GlobalFunc::ZEROS(na_each, na_proc); int iat = 0; - int irr = 0; - // std::cout << " in for_2d" << std::endl; @@ -58,8 +68,9 @@ void Record_adj::for_2d(const Parallel_Orbitals &pv) //GlobalC::GridD.Find_atom( tau1 ); GlobalC::GridD.Find_atom(GlobalC::ucell, tau1 ,T1, I1); const int start1 = GlobalC::ucell.itiaiw2iwt(T1, I1, 0); - - // (2) search among all adjacent atoms. + if(!gamma_only) pv.nlocstart[iat] = pv.nnr; + + // (2) search among all adjacent atoms. for (int ad = 0; ad < GlobalC::GridD.getAdjacentNum()+1; ++ad) { const int T2 = GlobalC::GridD.getType(ad); @@ -71,8 +82,12 @@ void Record_adj::for_2d(const Parallel_Orbitals &pv) double rcut = GlobalC::ORB.Phi[T1].getRcut() + GlobalC::ORB.Phi[T2].getRcut(); bool is_adj = false; - if( distance < rcut) is_adj = true; - else if( distance >= rcut) + if (distance < rcut) is_adj = true; + // there is another possibility that i and j are adjacent atoms. + // which is that are adjacents while are also + // adjacents, these considerations are only considered in k-point + // algorithm, + else if (distance >= rcut) { for (int ad0 = 0; ad0 < GlobalC::GridD.getAdjacentNum()+1; ++ad0) { @@ -98,39 +113,37 @@ void Record_adj::for_2d(const Parallel_Orbitals &pv) } } - - if(is_adj) { - ++na_each[iat]; - - for(int ii=0; iinw * GlobalV::NPOL; ++ii) - { - // the index of orbitals in this processor - const int iw1_all = start1 + ii; - const int mu = pv.trace_loc_row[iw1_all]; - if(mu<0)continue; + ++na_each[iat]; + if (!gamma_only) + { + for(int ii=0; iinw * GlobalV::NPOL; ++ii) + { + // the index of orbitals in this processor + const int iw1_all = start1 + ii; + const int mu = pv.trace_loc_row[iw1_all]; + if(mu<0)continue; - for(int jj=0; jj 10|| I2<0) exit(-1); + ++cb; } }//end ad // GlobalV::ofs_running << " nadj = " << cb << std::endl; diff --git a/source/src_lcao/record_adj.h b/source/src_lcao/record_adj.h index 6290f7ada07..5f93f022194 100644 --- a/source/src_lcao/record_adj.h +++ b/source/src_lcao/record_adj.h @@ -19,7 +19,7 @@ class Record_adj // This will record the orbitals according to // HPSEPS's 2D block division. //-------------------------------------------- - void for_2d(const Parallel_Orbitals &pv); + void for_2d(Parallel_Orbitals &pv, bool gamma_only); //-------------------------------------------- // This will record the orbitals according to diff --git a/source/src_lcao/run_md_lcao.cpp b/source/src_lcao/run_md_lcao.cpp index 21ca72d30c7..fdc49e85a8f 100644 --- a/source/src_lcao/run_md_lcao.cpp +++ b/source/src_lcao/run_md_lcao.cpp @@ -254,16 +254,19 @@ void Run_MD_LCAO::md_force_virial( // mohan add 2021-02-09 LCAO_Hamilt UHM_md; UHM_md.genH.LM = UHM_md.LM = &this->LM_md; + + Record_adj RA_md; + LOOP_elec LOE; - LOE.solve_elec_stru(istep + 1, LOC_md, LOWF_md, UHM_md); + LOE.solve_elec_stru(istep + 1, RA_md, LOC_md, LOWF_md, UHM_md); //to call the force of each atom ModuleBase::matrix fcs;//temp force matrix - Force_Stress_LCAO FSL; + Force_Stress_LCAO FSL(RA_md); FSL.getForceStress(GlobalV::FORCE, GlobalV::STRESS, GlobalV::TEST_FORCE, GlobalV::TEST_STRESS, LOC_md, LOWF_md, UHM_md, fcs, virial); - + RA_md.delete_grid(); for(int ion=0; ion 2D block distribution; - //--------------------------------------- - int nnr; + ///--------------------------------------- + /// number of elements(basis-pairs) in this processon + /// on all adjacent atoms-pairs(2D division) + ///--------------------------------------- + int nnr; int *nlocdim; int *nlocstart; @@ -72,10 +72,6 @@ struct Parallel_Orbitals /// (check whether local-index > 0 ) bool in_this_processor(const int& iw1_all, const int& iw2_all) const; - /// number of elements(basis-pairs) in this processon - /// on all adjacent atoms-pairs(2D division) - void cal_nnr(); - }; From b785375fe3cb556a6d097acb27704428208f8c5e Mon Sep 17 00:00:00 2001 From: maki49 <1579492865@qq.com> Date: Mon, 28 Feb 2022 09:39:46 +0800 Subject: [PATCH 45/52] remove GlobalV in ORB_control --- source/module_orbital/ORB_control.cpp | 221 ++++++++++++---------- source/module_orbital/ORB_control.h | 55 +++++- source/run_lcao.cpp | 10 +- source/src_parallel/parallel_orbitals.cpp | 136 ++++++------- source/src_parallel/parallel_orbitals.h | 6 +- 5 files changed, 254 insertions(+), 174 deletions(-) diff --git a/source/module_orbital/ORB_control.cpp b/source/module_orbital/ORB_control.cpp index fb2524ed2ce..51cb4fb0588 100644 --- a/source/module_orbital/ORB_control.cpp +++ b/source/module_orbital/ORB_control.cpp @@ -8,9 +8,35 @@ //#include "build_st_pw.h" +ORB_control::ORB_control( + const bool& gamma_only_in, + const int& nlocal_in, + const int& nbands_in, + const int& nspin_in, + const int& dsize_in, + const int& nb2d_in, + const int& dcolor_in, + const int& drank_in, + const int& myrank_in, + const std::string& calculation_in, + const std::string& ks_solver_in) : + gamma_only(gamma_only_in), + nlocal(nlocal_in), + nbands(nbands_in), + nspin(nspin_in), + dsize(dsize_in), + nb2d(nb2d_in), + dcolor(dcolor_in), + drank(drank_in), + myrank(myrank_in), + calculation(calculation_in), + ks_solver(ks_solver_in) +{ + this->ParaV.nspin = nspin_in; +} + ORB_control::ORB_control() {} - ORB_control::~ORB_control() {} @@ -83,7 +109,7 @@ void ORB_control::set_orb_tables( #ifdef __NORMAL #else - if(GlobalV::CALCULATION=="test") + if(calculation=="test") { ModuleBase::timer::tick("ORB_control","set_orb_tables"); return; @@ -131,20 +157,21 @@ void ORB_control::clear_after_ions( } -void ORB_control::setup_2d_division(void) +void ORB_control::setup_2d_division(std::ofstream& ofs_running, + std::ofstream& ofs_warning) { ModuleBase::TITLE("ORB_control","setup_2d_division"); - GlobalV::ofs_running << "\n SETUP THE DIVISION OF H/S MATRIX" << std::endl; + ofs_running << "\n SETUP THE DIVISION OF H/S MATRIX" << std::endl; // (1) calculate nrow, ncol, nloc. - if (GlobalV::KS_SOLVER=="genelpa" || GlobalV::KS_SOLVER=="hpseps" || GlobalV::KS_SOLVER=="scalpack" - || GlobalV::KS_SOLVER=="selinv" || GlobalV::KS_SOLVER=="scalapack_gvx") + if (ks_solver=="genelpa" || ks_solver=="hpseps" || ks_solver=="scalpack" + || ks_solver=="selinv" || ks_solver=="scalapack_gvx") { - GlobalV::ofs_running << " divide the H&S matrix using 2D block algorithms." << std::endl; + ofs_running << " divide the H&S matrix using 2D block algorithms." << std::endl; #ifdef __MPI // storage form of H and S matrices on each processor // is determined in 'divide_HS_2d' subroutine - this->divide_HS_2d(DIAG_WORLD); + this->divide_HS_2d(DIAG_WORLD, ofs_running, ofs_warning); #else ModuleBase::WARNING_QUIT("LCAO_Matrix::init","diago method is not ready."); #endif @@ -152,123 +179,124 @@ void ORB_control::setup_2d_division(void) else { // the full matrix - this->ParaV.nloc = GlobalV::NLOCAL * GlobalV::NLOCAL; + this->ParaV.nloc = nlocal * nlocal; } // (2) set the trace, then we can calculate the nnr. // for 2d: calculate po.nloc first, then trace_loc_row and trace_loc_col // for O(N): calculate the three together. - this->set_trace(); + this->set_trace(ofs_running); } -void ORB_control::set_parameters(void) +void ORB_control::set_parameters(std::ofstream& ofs_running, + std::ofstream& ofs_warning) { ModuleBase::TITLE("ORB_control","set_parameters"); Parallel_Orbitals* pv = &this->ParaV; // set loc_size - if(GlobalV::GAMMA_ONLY_LOCAL)//xiaohui add 2014-12-21 + if(gamma_only)//xiaohui add 2014-12-21 { - pv->loc_size=GlobalV::NBANDS/GlobalV::DSIZE; + pv->loc_size=nbands/dsize; // mohan add 2012-03-29 if(pv->loc_size==0) { - GlobalV::ofs_warning << " loc_size=0" << " in proc " << GlobalV::MY_RANK+1 << std::endl; - ModuleBase::WARNING_QUIT("ORB_control::set_parameters","NLOCAL < GlobalV::DSIZE"); + ofs_warning << " loc_size=0" << " in proc " << myrank+1 << std::endl; + ModuleBase::WARNING_QUIT("ORB_control::set_parameters","NLOCAL < dsize"); } - if (GlobalV::DRANKloc_size+=1; - if(pv->testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"local size",pv->loc_size); + if (drankloc_size+=1; + if(pv->testpb)ModuleBase::GlobalFunc::OUT(ofs_running,"local size",pv->loc_size); // set loc_sizes delete[] pv->loc_sizes; - pv->loc_sizes = new int[GlobalV::DSIZE]; - ModuleBase::GlobalFunc::ZEROS(pv->loc_sizes, GlobalV::DSIZE); + pv->loc_sizes = new int[dsize]; + ModuleBase::GlobalFunc::ZEROS(pv->loc_sizes, dsize); pv->lastband_in_proc = 0; pv->lastband_number = 0; int count_bands = 0; - for (int i=0; iloc_sizes[i]=GlobalV::NBANDS/GlobalV::DSIZE+1; + pv->loc_sizes[i]=nbands/dsize+1; } else { - pv->loc_sizes[i]=GlobalV::NBANDS/GlobalV::DSIZE; + pv->loc_sizes[i]=nbands/dsize; } count_bands += pv->loc_sizes[i]; - if (count_bands >= GlobalV::NBANDS) + if (count_bands >= nbands) { pv->lastband_in_proc = i; - pv->lastband_number = GlobalV::NBANDS - (count_bands - pv->loc_sizes[i]); + pv->lastband_number = nbands - (count_bands - pv->loc_sizes[i]); break; } } } else { - pv->loc_size=GlobalV::NLOCAL/GlobalV::DSIZE; + pv->loc_size=nlocal/dsize; // mohan add 2012-03-29 if(pv->loc_size==0) { - GlobalV::ofs_warning << " loc_size=0" << " in proc " << GlobalV::MY_RANK+1 << std::endl; - ModuleBase::WARNING_QUIT("ORB_control::set_parameters","NLOCAL < GlobalV::DSIZE"); + ofs_warning << " loc_size=0" << " in proc " << myrank+1 << std::endl; + ModuleBase::WARNING_QUIT("ORB_control::set_parameters","NLOCAL < DSIZE"); } - if (GlobalV::DRANKloc_size += 1; } - if(pv->testpb) ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"local size",pv->loc_size); + if(pv->testpb) ModuleBase::GlobalFunc::OUT(ofs_running,"local size",pv->loc_size); // set loc_sizes delete[] pv->loc_sizes; - pv->loc_sizes = new int[GlobalV::DSIZE]; - ModuleBase::GlobalFunc::ZEROS(pv->loc_sizes, GlobalV::DSIZE); + pv->loc_sizes = new int[dsize]; + ModuleBase::GlobalFunc::ZEROS(pv->loc_sizes, dsize); pv->lastband_in_proc = 0; pv->lastband_number = 0; int count_bands = 0; - for (int i=0; iloc_sizes[i]=GlobalV::NLOCAL/GlobalV::DSIZE+1; + pv->loc_sizes[i]=nlocal/dsize+1; } else { - pv->loc_sizes[i]=GlobalV::NLOCAL/GlobalV::DSIZE; + pv->loc_sizes[i]=nlocal/dsize; } count_bands += pv->loc_sizes[i]; - if (count_bands >= GlobalV::NBANDS) + if (count_bands >= nbands) { pv->lastband_in_proc = i; - pv->lastband_number = GlobalV::NBANDS - (count_bands - pv->loc_sizes[i]); + pv->lastband_number = nbands - (count_bands - pv->loc_sizes[i]); break; } } }//xiaohui add 2014-12-21 - if (GlobalV::KS_SOLVER=="hpseps") //LiuXh add 2021-09-06, clear memory, Z_LOC only used in hpseps solver + if (ks_solver=="hpseps") //LiuXh add 2021-09-06, clear memory, Z_LOC only used in hpseps solver { - pv->Z_LOC = new double*[GlobalV::NSPIN]; - for(int is=0; isZ_LOC = new double*[nspin]; + for(int is=0; isZ_LOC[is] = new double[pv->loc_size * GlobalV::NLOCAL]; - ModuleBase::GlobalFunc::ZEROS(pv->Z_LOC[is], pv->loc_size * GlobalV::NLOCAL); + pv->Z_LOC[is] = new double[pv->loc_size * nlocal]; + ModuleBase::GlobalFunc::ZEROS(pv->Z_LOC[is], pv->loc_size * nlocal); } pv->alloc_Z_LOC = true;//xiaohui add 2014-12-22 } - if(pv->testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"lastband_in_proc", pv->lastband_in_proc); - if(pv->testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"lastband_number", pv->lastband_number); + if(pv->testpb)ModuleBase::GlobalFunc::OUT(ofs_running,"lastband_in_proc", pv->lastband_in_proc); + if(pv->testpb)ModuleBase::GlobalFunc::OUT(ofs_running,"lastband_number", pv->lastband_number); return; } @@ -276,7 +304,8 @@ void ORB_control::set_parameters(void) #ifdef __MPI // creat the 'comm_2D' stratege. -void ORB_control::mpi_creat_cart(MPI_Comm *comm_2D, int prow, int pcol) +void ORB_control::mpi_creat_cart(MPI_Comm* comm_2D, + int prow, int pcol, std::ofstream& ofs_running) { ModuleBase::TITLE("ORB_control","mpi_creat_cart"); // the matrix is divided as ( dim[0] * dim[1] ) @@ -286,7 +315,7 @@ void ORB_control::mpi_creat_cart(MPI_Comm *comm_2D, int prow, int pcol) dim[0]=prow; dim[1]=pcol; - if(this->ParaV.testpb)GlobalV::ofs_running << " dim = " << dim[0] << " * " << dim[1] << std::endl; + if(this->ParaV.testpb) ofs_running << " dim = " << dim[0] << " * " << dim[1] << std::endl; MPI_Cart_create(DIAG_WORLD,2,dim,period,reorder,comm_2D); return; @@ -295,10 +324,12 @@ void ORB_control::mpi_creat_cart(MPI_Comm *comm_2D, int prow, int pcol) #ifdef __MPI void ORB_control::mat_2d(MPI_Comm vu, - const int &M_A, - const int &N_A, - const int &nb, - LocalMatrix &LM) + const int &M_A, + const int &N_A, + const int &nb, + LocalMatrix &LM, + std::ofstream& ofs_running, + std::ofstream& ofs_warning) { ModuleBase::TITLE("ORB_control", "mat_2d"); @@ -325,13 +356,13 @@ void ORB_control::mat_2d(MPI_Comm vu, block++; } - if(pv->testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"Total Row Blocks Number",block); + if(pv->testpb)ModuleBase::GlobalFunc::OUT(ofs_running,"Total Row Blocks Number",block); // mohan add 2010-09-12 if(dim[0]>block) { - GlobalV::ofs_warning << " cpu 2D distribution : " << dim[0] << "*" << dim[1] << std::endl; - GlobalV::ofs_warning << " but, the number of row blocks is " << block << std::endl; + ofs_warning << " cpu 2D distribution : " << dim[0] << "*" << dim[1] << std::endl; + ofs_warning << " but, the number of row blocks is " << block << std::endl; ModuleBase::WARNING_QUIT("ORB_control::mat_2d","some processor has no row blocks, try a smaller 'nb2d' parameter."); } @@ -345,7 +376,7 @@ void ORB_control::mat_2d(MPI_Comm vu, LM.row_b++; } - if(pv->testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"Local Row Block Number",LM.row_b); + if(pv->testpb)ModuleBase::GlobalFunc::OUT(ofs_running,"Local Row Block Number",LM.row_b); // (3) end_id indicates the last block belong to // which processor. @@ -358,7 +389,7 @@ void ORB_control::mat_2d(MPI_Comm vu, end_id=block%dim[0]-1; } - if(pv->testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"Ending Row Block in processor",end_id); + if(pv->testpb)ModuleBase::GlobalFunc::OUT(ofs_running,"Ending Row Block in processor",end_id); // (4) row_num : how many rows in this processors : // the one owns the last block is different. @@ -371,7 +402,7 @@ void ORB_control::mat_2d(MPI_Comm vu, LM.row_num=LM.row_b*nb; } - if(pv->testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"Local rows (including nb)",LM.row_num); + if(pv->testpb)ModuleBase::GlobalFunc::OUT(ofs_running,"Local rows (including nb)",LM.row_num); // (5) row_set, it's a global index : // save explicitly : every row in this processor @@ -384,17 +415,17 @@ void ORB_control::mat_2d(MPI_Comm vu, for (k=0; ktestpb)ModuleBase::GlobalFunc::OUT(ofs_running,"Total Col Blocks Number",block); if(dim[1]>block) { - GlobalV::ofs_warning << " cpu 2D distribution : " << dim[0] << "*" << dim[1] << std::endl; - GlobalV::ofs_warning << " but, the number of column blocks is " << block << std::endl; + ofs_warning << " cpu 2D distribution : " << dim[0] << "*" << dim[1] << std::endl; + ofs_warning << " but, the number of column blocks is " << block << std::endl; ModuleBase::WARNING_QUIT("ORB_control::mat_2d","some processor has no column blocks."); } @@ -404,7 +435,7 @@ void ORB_control::mat_2d(MPI_Comm vu, LM.col_b++; } - if(pv->testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"Local Row Block Number",LM.col_b); + if(pv->testpb)ModuleBase::GlobalFunc::OUT(ofs_running,"Local Row Block Number",LM.col_b); if (block%dim[1]==0) { @@ -415,7 +446,7 @@ void ORB_control::mat_2d(MPI_Comm vu, end_id=block%dim[1]-1; } - if(pv->testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"Ending Row Block in processor",end_id); + if(pv->testpb)ModuleBase::GlobalFunc::OUT(ofs_running,"Ending Row Block in processor",end_id); if (coord[1]==end_id) { @@ -426,7 +457,7 @@ void ORB_control::mat_2d(MPI_Comm vu, LM.col_num=LM.col_b*nb; } - if(pv->testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"Local columns (including nb)",LM.row_num); + if(pv->testpb)ModuleBase::GlobalFunc::OUT(ofs_running,"Local columns (including nb)",LM.row_num); delete[] LM.col_set; LM.col_set = new int[LM.col_num]; @@ -450,8 +481,8 @@ void ORB_control::mat_2d(MPI_Comm vu, } if(dim[1]>block) { - GlobalV::ofs_warning << " cpu 2D distribution : " << dim[0] << "*" << dim[1] << std::endl; - GlobalV::ofs_warning << " but, the number of bands-row-block is " << block << std::endl; + ofs_warning << " cpu 2D distribution : " << dim[0] << "*" << dim[1] << std::endl; + ofs_warning << " but, the number of bands-row-block is " << block << std::endl; ModuleBase::WARNING_QUIT("ORB_control::mat_2d","some processor has no bands-row-blocks."); } int col_b_bands = block / dim[1]; @@ -505,9 +536,9 @@ void ORB_control::data_distribution( int coord[2]; MPI_Cart_get(comm_2D,2,dim,period,coord); - if(pv->testpb) GlobalV::ofs_running << "\n dim = " << dim[0] << " * " << dim[1] << std::endl; - if(pv->testpb) GlobalV::ofs_running << " coord = ( " << coord[0] << " , " << coord[1] << ")." << std::endl; - if(pv->testpb) GlobalV::ofs_running << " n = " << n << std::endl; + if(pv->testpb) ofs_running << "\n dim = " << dim[0] << " * " << dim[1] << std::endl; + if(pv->testpb) ofs_running << " coord = ( " << coord[0] << " , " << coord[1] << ")." << std::endl; + if(pv->testpb) ofs_running << " n = " << n << std::endl; mpi_sub_col(comm_2D,&comm_col); mpi_sub_row(comm_2D,&comm_row); @@ -542,10 +573,10 @@ void ORB_control::data_distribution( for (int i=1; impi_creat_cart(&pv->comm_2D,dim[0],dim[1]); @@ -716,10 +747,10 @@ void ORB_control::readin( // call mat_2d this->mat_2d(pv->comm_2D, nlocal_tot,nlocal_tot,pv->nb,pv->MatrixInfo); - pv->loc_size=nlocal_tot/GlobalV::DSIZE; - if (GlobalV::DRANKloc_size=nlocal_tot/dsize; + if (drankdata_distribution(pv->comm_2D,fa,nlocal_tot,pv->nb,A,pv->MatrixInfo); - GlobalV::ofs_running << "\n Data distribution of S." << std::endl; + ofs_running << "\n Data distribution of S." << std::endl; this->data_distribution(pv->comm_2D,fb,nlocal_tot,pv->nb,B,pv->MatrixInfo); time1=MPI_Wtime(); @@ -741,15 +772,15 @@ void ORB_control::readin( char uplo = 'U'; pdgseps(pv->comm_2D,nlocal_tot,pv->nb,A,B,Z,eigen,pv->MatrixInfo,uplo,pv->loc_size,loc_pos); time2=MPI_Wtime(); - ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"time1",time1); - ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"time2",time2); + ModuleBase::GlobalFunc::OUT(ofs_running,"time1",time1); + ModuleBase::GlobalFunc::OUT(ofs_running,"time2",time2); //this->gath_eig(comm,n,eigvr,Z); - GlobalV::ofs_running << "\n " << std::setw(6) << "Band" << std::setw(25) << "Ry" << std::setw(25) << " eV" << std::endl; + ofs_running << "\n " << std::setw(6) << "Band" << std::setw(25) << "Ry" << std::setw(25) << " eV" << std::endl; for(int i=0; inspin; is++) { delete[] Z_LOC[is]; } @@ -59,70 +58,70 @@ bool Parallel_Orbitals::in_this_processor(const int &iw1_all, const int &iw2_all return true; } -void ORB_control::set_trace(void) +void ORB_control::set_trace(std::ofstream& ofs_running) { ModuleBase::TITLE("ORB_control","set_trace"); - assert(GlobalV::NLOCAL > 0); + assert(nlocal > 0); Parallel_Orbitals* pv = &this->ParaV; delete[] pv->trace_loc_row; delete[] pv->trace_loc_col; - ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"trace_loc_row dimension",GlobalV::NLOCAL); - ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"trace_loc_col dimension",GlobalV::NLOCAL); + ModuleBase::GlobalFunc::OUT(ofs_running,"trace_loc_row dimension",nlocal); + ModuleBase::GlobalFunc::OUT(ofs_running,"trace_loc_col dimension",nlocal); - pv->trace_loc_row = new int[GlobalV::NLOCAL]; - pv->trace_loc_col = new int[GlobalV::NLOCAL]; + pv->trace_loc_row = new int[nlocal]; + pv->trace_loc_col = new int[nlocal]; // mohan update 2011-04-07 - for(int i=0; itrace_loc_row[i] = -1; pv->trace_loc_col[i] = -1; } - ModuleBase::Memory::record("ORB_control","trace_loc_row",GlobalV::NLOCAL,"int"); - ModuleBase::Memory::record("ORB_control","trace_loc_col",GlobalV::NLOCAL,"int"); + ModuleBase::Memory::record("ORB_control","trace_loc_row",nlocal,"int"); + ModuleBase::Memory::record("ORB_control","trace_loc_col",nlocal,"int"); - if(GlobalV::KS_SOLVER=="lapack" - || GlobalV::KS_SOLVER=="cg" - || GlobalV::KS_SOLVER=="dav") //xiaohui add 2013-09-02 + if(ks_solver=="lapack" + || ks_solver=="cg" + || ks_solver=="dav") //xiaohui add 2013-09-02 { std::cout << " common settings for trace_loc_row and trace_loc_col " << std::endl; - for (int i=0; itrace_loc_row[i] = i; pv->trace_loc_col[i] = i; } - pv->nrow = GlobalV::NLOCAL; - pv->ncol = GlobalV::NLOCAL; + pv->nrow = nlocal; + pv->ncol = nlocal; } #ifdef __MPI - else if(GlobalV::KS_SOLVER=="scalpack" || GlobalV::KS_SOLVER=="genelpa" || GlobalV::KS_SOLVER=="hpseps" - || GlobalV::KS_SOLVER=="selinv" || GlobalV::KS_SOLVER=="scalapack_gvx") //xiaohui add 2013-09-02 + else if(ks_solver=="scalpack" || ks_solver=="genelpa" || ks_solver=="hpseps" + || ks_solver=="selinv" || ks_solver=="scalapack_gvx") //xiaohui add 2013-09-02 { - // GlobalV::ofs_running << " nrow=" << nrow << std::endl; + // ofs_running << " nrow=" << nrow << std::endl; for (int irow=0; irow< pv->nrow; irow++) { int global_row = pv->MatrixInfo.row_set[irow]; pv->trace_loc_row[global_row] = irow; - // GlobalV::ofs_running << " global_row=" << global_row + // ofs_running << " global_row=" << global_row // << " trace_loc_row=" << pv->trace_loc_row[global_row] << std::endl; } - // GlobalV::ofs_running << " ncol=" << ncol << std::endl; + // ofs_running << " ncol=" << ncol << std::endl; for (int icol=0; icol< pv->ncol; icol++) { int global_col = pv->MatrixInfo.col_set[icol]; pv->trace_loc_col[global_col] = icol; - // GlobalV::ofs_running << " global_col=" << global_col + // ofs_running << " global_col=" << global_col // << " trace_loc_col=" << pv->trace_loc_row[global_col] << std::endl; } } #endif else { - std::cout << " Parallel Orbial, GlobalV::DIAGO_TYPE = " << GlobalV::KS_SOLVER << std::endl; + std::cout << " Parallel Orbial, DIAGO_TYPE = " << ks_solver << std::endl; ModuleBase::WARNING_QUIT("ORB_control::set_trace","Check ks_solver."); } @@ -130,17 +129,17 @@ void ORB_control::set_trace(void) // print the trace for test. //--------------------------- /* - GlobalV::ofs_running << " " << std::setw(10) << "GlobalRow" << std::setw(10) << "LocalRow" << std::endl; - for(int i=0; itrace_loc_row[i] << std::endl; + ofs_running << " " << std::setw(10) << i << std::setw(10) << pv->trace_loc_row[i] << std::endl; } - GlobalV::ofs_running << " " << std::setw(10) << "GlobalCol" << std::setw(10) << "LocalCol" << std::endl; - for(int j=0; j0); - assert(GlobalV::DSIZE > 0); + assert(nlocal>0); + assert(dsize > 0); Parallel_Orbitals* pv = &this->ParaV; #ifdef __MPI DIAG_HPSEPS_WORLD=DIAG_WORLD; #endif - if(GlobalV::DCOLOR!=0) return; // mohan add 2012-01-13 + if(dcolor!=0) return; // mohan add 2012-01-13 // get the 2D index of computer. - pv->dim0 = (int)sqrt((double)GlobalV::DSIZE); //mohan update 2012/01/13 + pv->dim0 = (int)sqrt((double)dsize); //mohan update 2012/01/13 //while (GlobalV::NPROC_IN_POOL%dim0!=0) - while (GlobalV::DSIZE%pv->dim0!=0) + while (dsize%pv->dim0!=0) { pv->dim0 = pv->dim0 - 1; } assert(pv->dim0 > 0); - pv->dim1=GlobalV::DSIZE/pv->dim0; + pv->dim1=dsize/pv->dim0; - if(pv->testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"dim0",pv->dim0); - if(pv->testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"dim1",pv->dim1); + if(pv->testpb)ModuleBase::GlobalFunc::OUT(ofs_running,"dim0",pv->dim0); + if(pv->testpb)ModuleBase::GlobalFunc::OUT(ofs_running,"dim1",pv->dim1); #ifdef __MPI // mohan add 2011-04-16 - if(GlobalV::NB2D==0) + if(nb2d==0) { - if(GlobalV::NLOCAL>0) pv->nb = 1; - if(GlobalV::NLOCAL>500) pv->nb = 32; - if(GlobalV::NLOCAL>1000) pv->nb = 64; + if(nlocal>0) pv->nb = 1; + if(nlocal>500) pv->nb = 32; + if(nlocal>1000) pv->nb = 64; } - else if(GlobalV::NB2D>0) + else if(nb2d>0) { - pv->nb = GlobalV::NB2D; // mohan add 2010-06-28 + pv->nb = nb2d; // mohan add 2010-06-28 } - ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"nb2d", pv->nb); + ModuleBase::GlobalFunc::OUT(ofs_running,"nb2d", pv->nb); - this->set_parameters(); + this->set_parameters(ofs_running, ofs_warning); // call mpi_creat_cart - this->mpi_creat_cart(&pv->comm_2D,pv->dim0,pv->dim1); + this->mpi_creat_cart(&pv->comm_2D,pv->dim0,pv->dim1, ofs_running); // call mat_2d - this->mat_2d(pv->comm_2D, GlobalV::NLOCAL, GlobalV::NBANDS, pv->nb, pv->MatrixInfo); + this->mat_2d(pv->comm_2D, nlocal, nbands, pv->nb, + pv->MatrixInfo, ofs_running, ofs_warning); // mohan add 2010-06-29 pv->nrow = pv->MatrixInfo.row_num; @@ -243,31 +244,32 @@ void ORB_control::divide_HS_2d pv->nloc = pv->MatrixInfo.col_num * pv->MatrixInfo.row_num; // init blacs context for genelpa - if(GlobalV::KS_SOLVER=="genelpa" || GlobalV::KS_SOLVER=="scalapack_gvx") + if(ks_solver=="genelpa" || ks_solver=="scalapack_gvx") { - pv->blacs_ctxt=cart2blacs(pv->comm_2D, pv->dim0, pv->dim1, GlobalV::NLOCAL, GlobalV::NBANDS, pv->nb, pv->nrow, pv->desc, pv->desc_wfc); + pv->blacs_ctxt = cart2blacs(pv->comm_2D, pv->dim0, pv->dim1, + nlocal, nbands, pv->nb, pv->nrow, pv->desc, pv->desc_wfc); } #else // single processor used. - pv->nb = GlobalV::NLOCAL; - pv->nrow = GlobalV::NLOCAL; - pv->ncol = GlobalV::NLOCAL; - pv->nloc = GlobalV::NLOCAL * GlobalV::NLOCAL; - this->set_parameters(); - pv->MatrixInfo.row_b = 1; - pv->MatrixInfo.row_num = GlobalV::NLOCAL; + pv->nb = nlocal; + pv->nrow = nlocal; + pv->ncol = nlocal; + pv->nloc = nlocal * nlocal; + this->set_parameters(ofs_running, ofs_warning); + pv->MatrixInfo.row_b = 1; + pv->MatrixInfo.row_num = nlocal; delete[] pv->MatrixInfo.row_set; - pv->MatrixInfo.row_set = new int[GlobalV::NLOCAL]; - for(int i=0; iMatrixInfo.row_set = new int[nlocal]; + for(int i=0; iMatrixInfo.row_set[i]=i; } pv->MatrixInfo.row_pos=0; pv->MatrixInfo.col_b = 1; - pv->MatrixInfo.col_num = GlobalV::NLOCAL; + pv->MatrixInfo.col_num = nlocal; delete[] pv->MatrixInfo.col_set; - pv->MatrixInfo.col_set = new int[GlobalV::NLOCAL]; - for(int i=0; iMatrixInfo.col_set = new int[nlocal]; + for(int i=0; iMatrixInfo.col_set[i]=i; } @@ -275,8 +277,8 @@ void ORB_control::divide_HS_2d #endif assert(pv->nloc>0); - if(pv->testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"MatrixInfo.row_num",pv->MatrixInfo.row_num); - if(pv->testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"MatrixInfo.col_num",pv->MatrixInfo.col_num); - if(pv->testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"nloc",pv->nloc); + if(pv->testpb)ModuleBase::GlobalFunc::OUT(ofs_running,"MatrixInfo.row_num",pv->MatrixInfo.row_num); + if(pv->testpb)ModuleBase::GlobalFunc::OUT(ofs_running,"MatrixInfo.col_num",pv->MatrixInfo.col_num); + if(pv->testpb)ModuleBase::GlobalFunc::OUT(ofs_running,"nloc",pv->nloc); return; } diff --git a/source/src_parallel/parallel_orbitals.h b/source/src_parallel/parallel_orbitals.h index 08f3f953b6a..6cced619af2 100644 --- a/source/src_parallel/parallel_orbitals.h +++ b/source/src_parallel/parallel_orbitals.h @@ -11,6 +11,7 @@ struct Parallel_Orbitals { Parallel_Orbitals(); + Parallel_Orbitals(const int& nspin_in); ~Parallel_Orbitals(); /// map from global-index to local-index @@ -57,7 +58,8 @@ struct Parallel_Orbitals #endif /// only used in hpseps-diago - int* loc_sizes; + int nspin; + int* loc_sizes; int loc_size; bool alloc_Z_LOC; //xiaohui add 2014-12-22 double** Z_LOC; //xiaohui add 2014-06-19 @@ -67,7 +69,7 @@ struct Parallel_Orbitals // test parameter int testpb; - + /// check whether a basis element is in this processor /// (check whether local-index > 0 ) bool in_this_processor(const int& iw1_all, const int& iw2_all) const; From 7468328d868c8cbe5e1ed122f725b2fdbae8b2d8 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 28 Feb 2022 10:35:54 +0800 Subject: [PATCH 46/52] UT and annotations for math_bspline --- source/module_base/math_bspline.cpp | 8 +- source/module_base/math_bspline.h | 89 +++++++++---------- source/module_base/test/CMakeLists.txt | 4 + source/module_base/test/math_bspline_test.cpp | 62 +++++++++++++ source/src_pw/bspline_sf.cpp | 10 +-- 5 files changed, 115 insertions(+), 58 deletions(-) create mode 100644 source/module_base/test/math_bspline_test.cpp diff --git a/source/module_base/math_bspline.cpp b/source/module_base/math_bspline.cpp index 378fcd1ed5e..63d95e3f97b 100644 --- a/source/module_base/math_bspline.cpp +++ b/source/module_base/math_bspline.cpp @@ -32,18 +32,12 @@ namespace ModuleBase } } - void Bspline::cleanp() - { - delete[] bezier; - bezier = NULL; - } - double Bspline::bezier_ele(int n) { return this->bezier[n]; } - void Bspline::getbslpine(double x) + void Bspline::getbspline(double x) { bezier[0] = 1.0; for(int k = 1 ; k <= norder ; ++k) diff --git a/source/module_base/math_bspline.h b/source/module_base/math_bspline.h index 2eb72d8880e..86155bb79dc 100644 --- a/source/module_base/math_bspline.h +++ b/source/module_base/math_bspline.h @@ -4,57 +4,54 @@ namespace ModuleBase { - -// -//DESCRIPTION: -// A class to treat Cardinal B-spline interpolation. -// qianrui created 2021-09-14 -//MATH: -// Only uniform nodes are considered: xm-x[m-1]=Dx(>= 0) for control node: X={x0,x1,...,xm}; -// Any function p(x) can be written by -// p(x)=\sum_i{ci*M_ik(x)} (k->infinity), -// where M_ik is the i-th k-order Cardinal B-spline base function -// and ci is undetermined coefficient. -// M_i0 = H(x-xi)-H(x-x[i+1]), H(x): step function -// x-xi x[i+k+1]-x -// M_ik(x)= ---------*M_i(k-1)(x)+ ----------------*M_[i+1][k-1](x) ( xi <= x <= x[i+1] ) -// x[i+k]-xi x[i+k+1]-x[i+1] -// For uniform nodes: M_[i+1]k(x+Dx)=M_ik(x) -// If we define Bk[n] stores M_ik(x+n*Dx) for x in (xi,xi+Dx): -// x+n*Dx-xi xi+(k-n+1)*Dx-x -// Bk[n] = -----------*B(k-1)[n] + -----------------*B(k-1)[n-1] -// k*Dx k*Dx -//USAGE: -// ModuleBase::Bspline bp; -// bp.init(10,0.7,2); //Dx = 0.7, xi = 2 -// bp.getbslpine(0.5); //x = 0.5 -// cout<= 0) for control node: X={x0,x1,...,xm}; + * Any function p(x) can be written by + * p(x)=\sum_i{ci*M_ik(x)} (k->infinity), + * where M_ik is the i-th k-order Cardinal B-spline base function + * and ci is undetermined coefficient. + * M_i0 = H(x-xi)-H(x-x[i+1]), H(x): step function + * x-xi x[i+k+1]-x + * M_ik(x)= ---------*M_i(k-1)(x)+ ----------------*M_[i+1][k-1](x) ( xi <= x <= x[i+1] ) + * x[i+k]-xi x[i+k+1]-x[i+1] + * For uniform nodes: M_[i+1]k(x+Dx)=M_ik(x) + * If we define Bk[n] stores M_ik(x+n*Dx) for x in (xi,xi+Dx): + * x+n*Dx-xi xi+(k-n+1)*Dx-x + * Bk[n] = -----------*B(k-1)[n] + -----------------*B(k-1)[n-1] + * k*Dx k*Dx + * USAGE: + * ModuleBase::Bspline bp; + * bp.init(10,0.7,2); //Dx = 0.7, xi = 2 + * bp.getbslpine(0.5); //x = 0.5 + * cout<= 0 - double Dx; //Dx: the interval of control node - double xi; // xi: the starting point - double *bezier; //bezier[n] = Bk[n] + private: + int norder; // the order of bezier base; norder >= 0 + double Dx; // Dx: the interval of control node + double xi; // xi: the starting point + double *bezier; // bezier[n] = Bk[n] - public: - Bspline(); - ~Bspline(); + public: + Bspline(); + ~Bspline(); - //Init norder, Dx, xi - void init(int norderin, double Dxin, double xiin); + void init(int norderin, double Dxin, double xiin); - //delete[] bezier - void cleanp(); + // Get the result of i-th bezier base functions for different input x+xi+n*Dx. + // x should be in [0,Dx] + // n-th result is stored in bezier[n]; + void getbspline(double x); - //Get the result of i-th bezier base functions for different input x+xi+n*Dx. - //x should be in [0,Dx] - //n-th result is stored in bezier[n]; - void getbslpine(double x); - - //get the element of bezier - double bezier_ele(int n); + // get the element of bezier + double bezier_ele(int n); }; -} +} // namespace ModuleBase #endif diff --git a/source/module_base/test/CMakeLists.txt b/source/module_base/test/CMakeLists.txt index 7fa8a6c5e2c..46359f2b5eb 100644 --- a/source/module_base/test/CMakeLists.txt +++ b/source/module_base/test/CMakeLists.txt @@ -82,3 +82,7 @@ AddTest( LIBS ${math_libs} SOURCES math_ylmreal_test.cpp ../math_ylmreal.cpp ../realarray.cpp ../timer.cpp ../matrix.cpp ) +AddTest( + TARGET base_math_bspline + SOURCES math_bspline_test.cpp ../math_bspline.cpp +) diff --git a/source/module_base/test/math_bspline_test.cpp b/source/module_base/test/math_bspline_test.cpp new file mode 100644 index 00000000000..2de14b42826 --- /dev/null +++ b/source/module_base/test/math_bspline_test.cpp @@ -0,0 +1,62 @@ +#include "../math_bspline.h" +#include "gtest/gtest.h" + +/************************************************ + * unit test of class Bspline + ***********************************************/ + +/** + * - Tested Functions: + * - Init + * - norder must be even + * - norder mush be positive + * - Properties + * - \sum_i M_n(u+i) = 1 (i=0,1,2,...n) + * + */ + +class MathBsplineTest : public testing::Test +{ +protected: + ModuleBase::Bspline bp; + int norder; +}; + +TEST_F(MathBsplineTest,Init) +{ + EXPECT_DEATH( + { + norder = 3; // norder must be even + bp.init(norder,0.05,0); + },"" + ); + EXPECT_DEATH( + { + norder = 0; // norder must be positive + bp.init(norder,0.05,0); + },"" + ); +} + +// summation over n is unity +TEST_F(MathBsplineTest,Properties) +{ + int by = 2; + for (norder=2;norder<=20;norder=norder+by) + { + bp.init(norder,1.0,0); + bp.getbspline(0.2); + double sum=0.0; + //std::cout << "\n" << "norder : "<< norder< *b1, complex *b2, complex ci_tpi = ModuleBase::NEG_IMAG_UNIT * ModuleBase::TWO_PI; ModuleBase::Bspline bsp; bsp.init(norder, 1, 0); - bsp.getbslpine(1.0); + bsp.getbspline(1.0); for(int ix = 0 ; ix < this->nx ; ++ix) { complex fracx=0; @@ -137,4 +137,4 @@ void PW_Basis:: bsplinecoef(complex *b1, complex *b2, complex Date: Mon, 28 Feb 2022 10:46:41 +0800 Subject: [PATCH 47/52] move parallel_orbitals.h, .cpp files from src_parallel into module_orbital --- source/module_md/MD_func.h | 2 +- source/module_orbital/CMakeLists.txt | 1 + source/module_orbital/ORB_control.h | 3 +-- source/{src_parallel => module_orbital}/parallel_orbitals.cpp | 0 source/{src_parallel => module_orbital}/parallel_orbitals.h | 3 +-- source/module_orbital/test/CMakeLists.txt | 2 +- source/src_lcao/LCAO_matrix.h | 2 +- source/src_lcao/LOOP_ions.cpp | 2 +- source/src_lcao/dftu.h | 2 +- source/src_lcao/record_adj.h | 2 +- source/src_lcao/run_md_lcao.cpp | 2 +- source/src_parallel/CMakeLists.txt | 1 - source/src_pdiag/pdiag_double.h | 2 +- source/src_ri/exx_abfs-parallel-communicate-dm3.h | 2 +- source/src_ri/exx_abfs-parallel-communicate-function.h | 2 +- source/src_ri/exx_abfs-parallel-communicate-hexx.h | 2 +- 16 files changed, 14 insertions(+), 16 deletions(-) rename source/{src_parallel => module_orbital}/parallel_orbitals.cpp (100%) rename source/{src_parallel => module_orbital}/parallel_orbitals.h (96%) diff --git a/source/module_md/MD_func.h b/source/module_md/MD_func.h index cccbaf3dd46..d313a903cf1 100644 --- a/source/module_md/MD_func.h +++ b/source/module_md/MD_func.h @@ -5,7 +5,7 @@ #include "../module_cell/unitcell_pseudo.h" #include "../module_base/matrix.h" #ifdef __LCAO -#include "../src_parallel/parallel_orbitals.h" +#include "../module_orbital/parallel_orbitals.h" #endif #include "module_ensolver/en_solver.h" diff --git a/source/module_orbital/CMakeLists.txt b/source/module_orbital/CMakeLists.txt index bd630fa4f0e..f5e4157844e 100644 --- a/source/module_orbital/CMakeLists.txt +++ b/source/module_orbital/CMakeLists.txt @@ -12,6 +12,7 @@ add_library( ORB_table_alpha.cpp ORB_table_beta.cpp ORB_table_phi.cpp + parallel_orbitals.cpp ) set(CMAKE_CXX_STANDARD_REQUIRED ON) diff --git a/source/module_orbital/ORB_control.h b/source/module_orbital/ORB_control.h index 0821aa06f27..de72021fee3 100644 --- a/source/module_orbital/ORB_control.h +++ b/source/module_orbital/ORB_control.h @@ -1,8 +1,7 @@ #ifndef ORB_CONTROL_H #define ORB_CONTROL_H -#include "src_parallel/parallel_orbitals.h" -#include "src_parallel/parallel_global.h" +#include "parallel_orbitals.h" #include "ORB_gen_tables.h" #include "ORB_read.h" diff --git a/source/src_parallel/parallel_orbitals.cpp b/source/module_orbital/parallel_orbitals.cpp similarity index 100% rename from source/src_parallel/parallel_orbitals.cpp rename to source/module_orbital/parallel_orbitals.cpp diff --git a/source/src_parallel/parallel_orbitals.h b/source/module_orbital/parallel_orbitals.h similarity index 96% rename from source/src_parallel/parallel_orbitals.h rename to source/module_orbital/parallel_orbitals.h index 6cced619af2..4d4862bb06b 100644 --- a/source/src_parallel/parallel_orbitals.h +++ b/source/module_orbital/parallel_orbitals.h @@ -11,7 +11,6 @@ struct Parallel_Orbitals { Parallel_Orbitals(); - Parallel_Orbitals(const int& nspin_in); ~Parallel_Orbitals(); /// map from global-index to local-index @@ -58,7 +57,7 @@ struct Parallel_Orbitals #endif /// only used in hpseps-diago - int nspin; + int nspin = 1; int* loc_sizes; int loc_size; bool alloc_Z_LOC; //xiaohui add 2014-12-22 diff --git a/source/module_orbital/test/CMakeLists.txt b/source/module_orbital/test/CMakeLists.txt index 96b7fb5fbcf..c4f71f09fa1 100644 --- a/source/module_orbital/test/CMakeLists.txt +++ b/source/module_orbital/test/CMakeLists.txt @@ -34,8 +34,8 @@ list(APPEND depend_files ../ORB_table_phi.cpp ../ORB_table_alpha.cpp ../ORB_gen_tables.cpp + ../parallel_orbitals.cpp ../../src_lcao/center2_orb-orb11.cpp - ../../src_parallel/parallel_orbitals.cpp ) AddTest( TARGET orbital_equal_test diff --git a/source/src_lcao/LCAO_matrix.h b/source/src_lcao/LCAO_matrix.h index 30e619005e3..53fce2a94fa 100644 --- a/source/src_lcao/LCAO_matrix.h +++ b/source/src_lcao/LCAO_matrix.h @@ -5,7 +5,7 @@ #include "../module_base/global_variable.h" #include "../module_base/vector3.h" #include "../module_base/complexmatrix.h" -#include "../src_parallel/parallel_orbitals.h" +#include "../module_orbital/parallel_orbitals.h" // add by jingan for map<> in 2021-12-2, will be deleted in the future #include "../src_ri/abfs-vector3_order.h" diff --git a/source/src_lcao/LOOP_ions.cpp b/source/src_lcao/LOOP_ions.cpp index b38c54f52dc..64d85fbde0a 100644 --- a/source/src_lcao/LOOP_ions.cpp +++ b/source/src_lcao/LOOP_ions.cpp @@ -1,6 +1,6 @@ #include "LOOP_ions.h" #include "../src_pw/global.h" -#include "../src_parallel/parallel_orbitals.h" +#include "../module_orbital/parallel_orbitals.h" #include "../src_pdiag/pdiag_double.h" #include "FORCE_STRESS.h" #include "../module_base/global_function.h" diff --git a/source/src_lcao/dftu.h b/source/src_lcao/dftu.h index 6e51f1cba86..e97e4db5849 100644 --- a/source/src_lcao/dftu.h +++ b/source/src_lcao/dftu.h @@ -11,7 +11,7 @@ #include "dftu_relax.h" #include "../module_cell/unitcell_pseudo.h" -#include "../src_parallel/parallel_orbitals.h" +#include "../module_orbital/parallel_orbitals.h" using namespace std; diff --git a/source/src_lcao/record_adj.h b/source/src_lcao/record_adj.h index 5f93f022194..05eb8c03722 100644 --- a/source/src_lcao/record_adj.h +++ b/source/src_lcao/record_adj.h @@ -2,7 +2,7 @@ #define RECORD_ADJ_H #include "grid_technique.h" -#include "src_parallel/parallel_orbitals.h" +#include "module_orbital/parallel_orbitals.h" //--------------------------------------------------- diff --git a/source/src_lcao/run_md_lcao.cpp b/source/src_lcao/run_md_lcao.cpp index fdc49e85a8f..10575077e24 100644 --- a/source/src_lcao/run_md_lcao.cpp +++ b/source/src_lcao/run_md_lcao.cpp @@ -5,7 +5,7 @@ #include "../src_pw/vdwd2.h" #include "../src_pw/vdwd2_parameters.h" #include "../src_pw/vdwd3_parameters.h" -#include "../src_parallel/parallel_orbitals.h" +#include "../module_orbital/parallel_orbitals.h" #include "../src_pdiag/pdiag_double.h" #include "../src_io/write_HS.h" #include "../src_io/cal_r_overlap_R.h" diff --git a/source/src_parallel/CMakeLists.txt b/source/src_parallel/CMakeLists.txt index 0eb4a88689f..a3b9122edd0 100644 --- a/source/src_parallel/CMakeLists.txt +++ b/source/src_parallel/CMakeLists.txt @@ -4,7 +4,6 @@ list(APPEND objects parallel_global.cpp parallel_grid.cpp parallel_kpoints.cpp - parallel_orbitals.cpp parallel_pw.cpp parallel_reduce.cpp subgrid_oper.cpp diff --git a/source/src_pdiag/pdiag_double.h b/source/src_pdiag/pdiag_double.h index 247aac58f63..d22efee798a 100644 --- a/source/src_pdiag/pdiag_double.h +++ b/source/src_pdiag/pdiag_double.h @@ -6,7 +6,7 @@ #include "../module_base/matrix.h" #include "../module_base/complexmatrix.h" #include "diag_scalapack_gvx.h" -#include "src_parallel/parallel_orbitals.h" +#include "module_orbital/parallel_orbitals.h" #include "src_lcao/local_orbital_wfc.h" class Pdiag_Double diff --git a/source/src_ri/exx_abfs-parallel-communicate-dm3.h b/source/src_ri/exx_abfs-parallel-communicate-dm3.h index 43e30dd9f58..98851229b15 100644 --- a/source/src_ri/exx_abfs-parallel-communicate-dm3.h +++ b/source/src_ri/exx_abfs-parallel-communicate-dm3.h @@ -4,7 +4,7 @@ #include "exx_abfs-parallel.h" #include "abfs-vector3_order.h" #include "../module_base/complexmatrix.h" -#include "src_parallel/parallel_orbitals.h" +#include "module_orbital/parallel_orbitals.h" #include "src_lcao/local_orbital_charge.h" #ifdef __MPI #include "mpi.h" diff --git a/source/src_ri/exx_abfs-parallel-communicate-function.h b/source/src_ri/exx_abfs-parallel-communicate-function.h index 215cf095e2b..0db2427fc61 100644 --- a/source/src_ri/exx_abfs-parallel-communicate-function.h +++ b/source/src_ri/exx_abfs-parallel-communicate-function.h @@ -3,7 +3,7 @@ #define EXX_ABFS_PARALLEL_COMMUNICATE_FUNCTION_H #include "exx_abfs-parallel.h" -#include "src_parallel/parallel_orbitals.h" +#include "module_orbital/parallel_orbitals.h" #include #include #ifdef __MPI diff --git a/source/src_ri/exx_abfs-parallel-communicate-hexx.h b/source/src_ri/exx_abfs-parallel-communicate-hexx.h index f33bb4e8e45..aca55569d4b 100644 --- a/source/src_ri/exx_abfs-parallel-communicate-hexx.h +++ b/source/src_ri/exx_abfs-parallel-communicate-hexx.h @@ -13,7 +13,7 @@ #include #include "mpi.h" #include -#include "src_parallel/parallel_orbitals.h" +#include "module_orbital/parallel_orbitals.h" // mohan comment out 2021-02-06 //#include From eea33c5d6728381f96f98e38e7a17ee5bb82c33a Mon Sep 17 00:00:00 2001 From: maki49 <1579492865@qq.com> Date: Mon, 28 Feb 2022 11:18:44 +0800 Subject: [PATCH 48/52] remove newdm parameters in test case 260 --- tests/integrate/260_NO_15_PK_PU_AF/INPUT | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/integrate/260_NO_15_PK_PU_AF/INPUT b/tests/integrate/260_NO_15_PK_PU_AF/INPUT index e18597a2876..3022e1417a5 100644 --- a/tests/integrate/260_NO_15_PK_PU_AF/INPUT +++ b/tests/integrate/260_NO_15_PK_PU_AF/INPUT @@ -31,7 +31,6 @@ basis_type lcao gamma_only 1 symmetry 0 nspin 2 -newdm 1 #Parameter DFT+U dft_plus_u 1 From 61bef0129cb4e5a0f03a0d61903ab5535fc3292a Mon Sep 17 00:00:00 2001 From: maki49 <1579492865@qq.com> Date: Mon, 28 Feb 2022 11:25:05 +0800 Subject: [PATCH 49/52] remove useless output --- source/module_orbital/ORB_control.cpp | 2 -- source/src_pdiag/pdiag_double.cpp | 2 -- 2 files changed, 4 deletions(-) diff --git a/source/module_orbital/ORB_control.cpp b/source/module_orbital/ORB_control.cpp index 51cb4fb0588..d0580987816 100644 --- a/source/module_orbital/ORB_control.cpp +++ b/source/module_orbital/ORB_control.cpp @@ -508,8 +508,6 @@ void ORB_control::mat_2d(MPI_Comm vu, } pv->nloc_wfc = pv->ncol_bands * LM.row_num; - std::cout << pv->nloc_wfc << " " << pv->ncol_bands << " " << LM.row_num << std::endl; - return; } #endif diff --git a/source/src_pdiag/pdiag_double.cpp b/source/src_pdiag/pdiag_double.cpp index 29f64b3ab42..0e538e55a17 100644 --- a/source/src_pdiag/pdiag_double.cpp +++ b/source/src_pdiag/pdiag_double.cpp @@ -221,9 +221,7 @@ void Pdiag_Double::diago_double_begin( ModuleBase::timer::tick("Diago_LCAO_Matrix","elpa_solve"); int elpa_error; - std::cout << "before elpa" << std::endl; elpa_generalized_eigenvectors_d(handle, h_mat, Stmp, eigen, lowf.wfc_gamma[ik].c, is_already_decomposed, &elpa_error); - std::cout << "after elpa" << std::endl; ModuleBase::timer::tick("Diago_LCAO_Matrix", "elpa_solve"); ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"K-S equation was solved by genelpa2"); From ebd44d5479676b7c92ee635d0a91b1bd05ac804f Mon Sep 17 00:00:00 2001 From: maki49 <1579492865@qq.com> Date: Mon, 28 Feb 2022 14:29:39 +0800 Subject: [PATCH 50/52] apply Ensolver to lcao-line --- source/driver.cpp | 9 +- source/module_ensolver/CMakeLists.txt | 1 + .../FP/KSDFT/LCAO/ks_scf_lcao.cpp | 136 ++++++++++++++++++ .../FP/KSDFT/LCAO/ks_scf_lcao.h | 39 +++++ .../module_ensolver/FP/KSDFT/PW/ks_scf_pw.h | 9 +- source/module_ensolver/FP/KSDFT/ks_scf.h | 5 +- source/module_ensolver/Makefile.ensolver | 6 +- source/module_ensolver/en_solver.cpp | 7 +- source/module_ensolver/en_solver.h | 18 ++- source/module_orbital/ORB_control.cpp | 2 +- source/run_lcao.cpp | 112 +++------------ source/src_lcao/LOOP_cell.cpp | 4 +- source/src_lcao/LOOP_cell.h | 3 +- source/src_lcao/LOOP_ions.cpp | 7 +- source/src_lcao/LOOP_ions.h | 3 +- 15 files changed, 244 insertions(+), 117 deletions(-) create mode 100644 source/module_ensolver/FP/KSDFT/LCAO/ks_scf_lcao.cpp create mode 100644 source/module_ensolver/FP/KSDFT/LCAO/ks_scf_lcao.h diff --git a/source/driver.cpp b/source/driver.cpp index 4a8fe8af3fe..42412311a2d 100644 --- a/source/driver.cpp +++ b/source/driver.cpp @@ -104,9 +104,12 @@ void Driver::atomic_world(void) } #ifdef __LCAO else if(GlobalV::BASIS_TYPE=="lcao") - { - Run_lcao::lcao_line(p_ensolver); - } + { + use_ensol = "ksdft_lcao"; + ModuleEnSover::init_esolver(p_ensolver, use_ensol); + Run_lcao::lcao_line(p_ensolver); + ModuleEnSover::clean_esolver(p_ensolver); + } #endif ModuleBase::timer::finish( GlobalV::ofs_running ); diff --git a/source/module_ensolver/CMakeLists.txt b/source/module_ensolver/CMakeLists.txt index 8f9d75dfcb0..f24a61914ab 100644 --- a/source/module_ensolver/CMakeLists.txt +++ b/source/module_ensolver/CMakeLists.txt @@ -5,6 +5,7 @@ add_library( FP/ab_initio.cpp FP/KSDFT/ks_scf.cpp FP/KSDFT/PW/ks_scf_pw.cpp + FP/KSDFT/LCAO/ks_scf_lcao.cpp ) diff --git a/source/module_ensolver/FP/KSDFT/LCAO/ks_scf_lcao.cpp b/source/module_ensolver/FP/KSDFT/LCAO/ks_scf_lcao.cpp new file mode 100644 index 00000000000..3aad41d62ca --- /dev/null +++ b/source/module_ensolver/FP/KSDFT/LCAO/ks_scf_lcao.cpp @@ -0,0 +1,136 @@ +#include "ks_scf_lcao.h" + +//--------------temporary---------------------------- +#include "../../../../src_pw/global.h" +#include "../../../../module_base/global_function.h" +#include "src_io/print_info.h" + +#ifdef __DEEPKS +#include "module_deepks/LCAO_deepks.h" +#endif +//-----force------------------- + +//-----stress------------------ + +//--------------------------------------------------- + +namespace ModuleEnSover +{ + +void KS_SCF_LCAO::Init(Input &inp, UnitCell_pseudo &ucell) +{ + + // setup GlobalV::NBANDS + // Yu Liu add 2021-07-03 + GlobalC::CHR.cal_nelec(); + + // mohan add 2010-09-06 + // Yu Liu move here 2021-06-27 + // because the number of element type + // will easily be ignored, so here + // I warn the user again for each type. + for(int it=0; itLOE.solve_elec_stru(istep, ra, loc, lowf, uhm); + return ; +} + +void KS_SCF_LCAO::cal_Energy(energy &en) +{ + +} + +void KS_SCF_LCAO::cal_Force(ModuleBase::matrix &force) +{ + +} +void KS_SCF_LCAO::cal_Stress(ModuleBase::matrix &stress) +{ + +} + +} \ No newline at end of file diff --git a/source/module_ensolver/FP/KSDFT/LCAO/ks_scf_lcao.h b/source/module_ensolver/FP/KSDFT/LCAO/ks_scf_lcao.h new file mode 100644 index 00000000000..0f9d36d0785 --- /dev/null +++ b/source/module_ensolver/FP/KSDFT/LCAO/ks_scf_lcao.h @@ -0,0 +1,39 @@ +#ifndef KS_SCF_LCAO_H +#define KS_SCF_LCAO_H +#include "../ks_scf.h" + +#include "src_lcao/LOOP_elec.h" + +namespace ModuleEnSover +{ + +class KS_SCF_LCAO: public KS_SCF +{ +public: + KS_SCF_LCAO() + { + tag = "KS_SCF_LCAO"; + } + void Init(Input& inp, UnitCell_pseudo& cell) override; + + void Run(int istep, + Record_adj& ra, + Local_Orbital_Charge& loc, + Local_Orbital_wfc& lowf, + LCAO_Hamilt& uhm) override; + void Run(int istep, UnitCell_pseudo& cell) override {}; + + void cal_Energy(energy& en) override; + void cal_Force(ModuleBase::matrix &force) override; + void cal_Stress(ModuleBase::matrix& stress) override; + +private: + LOOP_elec LOE; +}; + +///Basis_lcao +///ORB_control orb_con; + + +} +#endif \ No newline at end of file diff --git a/source/module_ensolver/FP/KSDFT/PW/ks_scf_pw.h b/source/module_ensolver/FP/KSDFT/PW/ks_scf_pw.h index a42e2dffb89..1815053be15 100644 --- a/source/module_ensolver/FP/KSDFT/PW/ks_scf_pw.h +++ b/source/module_ensolver/FP/KSDFT/PW/ks_scf_pw.h @@ -19,8 +19,13 @@ class KS_SCF_PW: public KS_SCF tag = "KS_SCF_PW"; } void Init(Input &inp, UnitCell_pseudo &cell) override; - void Run(int istep, UnitCell_pseudo &cell) override; - void cal_Energy(energy &en) override; + void Run(int istep, UnitCell_pseudo& cell) override; + void Run(int istep, + Record_adj& ra, + Local_Orbital_Charge& loc, + Local_Orbital_wfc& lowf, + LCAO_Hamilt& uhm) override {}; + void cal_Energy(energy& en) override; void cal_Force(ModuleBase::matrix &force) override; void cal_Stress(ModuleBase::matrix &stress) override; diff --git a/source/module_ensolver/FP/KSDFT/ks_scf.h b/source/module_ensolver/FP/KSDFT/ks_scf.h index fd3acbe20c3..160eca6b746 100644 --- a/source/module_ensolver/FP/KSDFT/ks_scf.h +++ b/source/module_ensolver/FP/KSDFT/ks_scf.h @@ -1,3 +1,5 @@ +#ifndef KS_SCF_H +#define KS_SCF_H #include "../ab_initio.h" // #include "estates.h" // #include "h2e.h" @@ -11,4 +13,5 @@ class KS_SCF: public ab_initio // Estate *p_es; // H2E *p_h2e; }; -} \ No newline at end of file +} +#endif \ No newline at end of file diff --git a/source/module_ensolver/Makefile.ensolver b/source/module_ensolver/Makefile.ensolver index ccc4e33cdfd..bec2e80b169 100644 --- a/source/module_ensolver/Makefile.ensolver +++ b/source/module_ensolver/Makefile.ensolver @@ -2,12 +2,14 @@ VPATH:=$(VPATH)\ :./module_ensolver\ :./module_ensolver/FP\ :./module_ensolver/FP/KSDFT\ -:./module_ensolver/FP/KSDFT/PW +:./module_ensolver/FP/KSDFT/PW\ +:./module_ensolver/FP/KSDFT/LCAO\ OBJS_ENSOLVER=en_solver.o\ ab_initio.o\ ks_scf.o\ -ks_scf_pw.o +ks_scf_pw.o\ +ks_scf_lcao.o OBJS_FIRST_PRINCIPLES:= ${OBJS_FIRST_PRINCIPLES} \ ${OBJS_ENSOLVER} \ No newline at end of file diff --git a/source/module_ensolver/en_solver.cpp b/source/module_ensolver/en_solver.cpp index e8458240687..0457f94a241 100644 --- a/source/module_ensolver/en_solver.cpp +++ b/source/module_ensolver/en_solver.cpp @@ -1,5 +1,6 @@ #include "en_solver.h" #include "FP/KSDFT/PW/ks_scf_pw.h" +#include "FP/KSDFT/LCAO/ks_scf_lcao.h" #include "FP/OFDFT/ofdft.h" #include "stdio.h" namespace ModuleEnSover @@ -17,7 +18,11 @@ void init_esolver(En_Solver *&p_ensolver, const string use_esol) { p_ensolver = new KS_SCF_PW(); } - // else if(use_esol == "sdft_pw") + else if(use_esol == "ksdft_lcao") + { + p_ensolver = new KS_SCF_LCAO(); + } + // else if(use_esol == "sdft_pw") // { // p_ensolver = new KS_SCF_PW(true); // } diff --git a/source/module_ensolver/en_solver.h b/source/module_ensolver/en_solver.h index f2ec6a0510b..d5e94d7525c 100644 --- a/source/module_ensolver/en_solver.h +++ b/source/module_ensolver/en_solver.h @@ -5,6 +5,12 @@ #include "../module_cell/unitcell_pseudo.h" #include "../src_pw/energy.h" #include "../module_base/matrix.h" +//--------------temporary---------------------------- +#include "src_lcao/record_adj.h" +#include "src_lcao/local_orbital_charge.h" +#include "src_lcao/local_orbital_wfc.h" +#include "src_lcao/LCAO_hamilt.h" +//--------------\temporary---------------------------- namespace ModuleEnSover { @@ -18,10 +24,16 @@ class En_Solver //virtual void Init(Input_EnSolver &inp, matrix &lattice_v)=0 virtual void Init(Input &inp, UnitCell_pseudo &cell)=0; - + + /// These two virtual `Run` will be merged in the future. //virtual void Run(int istep, Atom &atom) = 0; - virtual void Run(int istep, UnitCell_pseudo &cell) = 0; - + virtual void Run(int istep, UnitCell_pseudo& cell) = 0; + virtual void Run(int istep, + Record_adj& ra /**< would be a 2nd-module of Cell*/, + Local_Orbital_Charge& loc /**< EState*/, + Local_Orbital_wfc& lowf /**< Psi*/, + LCAO_Hamilt& uhm /**< Hamilt*/) = 0; + virtual void cal_Energy(energy &en) = 0; virtual void cal_Force(ModuleBase::matrix &force) = 0; virtual void cal_Stress(ModuleBase::matrix &stress) = 0; diff --git a/source/module_orbital/ORB_control.cpp b/source/module_orbital/ORB_control.cpp index d0580987816..933c15d04dc 100644 --- a/source/module_orbital/ORB_control.cpp +++ b/source/module_orbital/ORB_control.cpp @@ -204,7 +204,7 @@ void ORB_control::set_parameters(std::ofstream& ofs_running, if(pv->loc_size==0) { ofs_warning << " loc_size=0" << " in proc " << myrank+1 << std::endl; - ModuleBase::WARNING_QUIT("ORB_control::set_parameters","NLOCAL < dsize"); + ModuleBase::WARNING_QUIT("ORB_control::set_parameters","NLOCAL < DSIZE"); } if (drankloc_size+=1; diff --git a/source/run_lcao.cpp b/source/run_lcao.cpp index 50fe0df7da2..4d005dfac14 100644 --- a/source/run_lcao.cpp +++ b/source/run_lcao.cpp @@ -7,11 +7,7 @@ #include "module_neighbor/sltk_atom_arrange.h" #include "src_lcao/LOOP_cell.h" #include "src_io/print_info.h" -#include "module_symmetry/symmetry.h" #include "src_lcao/run_md_lcao.h" -#ifdef __DEEPKS -#include "module_deepks/LCAO_deepks.h" -#endif Run_lcao::Run_lcao(){} Run_lcao::~Run_lcao(){} @@ -21,7 +17,7 @@ void Run_lcao::lcao_line(ModuleEnSover::En_Solver *p_ensolver) { ModuleBase::TITLE("Run_lcao","lcao_line"); ModuleBase::timer::tick("Run_lcao", "lcao_line"); - + //------------------------------------init Cell---------------------------- // Setup the unitcell. // improvement: a) separating the first reading of the atom_card and subsequent // cell relaxation. b) put GlobalV::NLOCAL and GlobalV::NBANDS as input parameters @@ -48,39 +44,22 @@ void Run_lcao::lcao_line(ModuleEnSover::En_Solver *p_ensolver) GlobalV::SEARCH_RADIUS, GlobalV::test_atom_input, INPUT.test_just_neighbor); - } - // setup GlobalV::NBANDS - // Yu Liu add 2021-07-03 - GlobalC::CHR.cal_nelec(); - - // mohan add 2010-09-06 - // Yu Liu move here 2021-06-27 - // because the number of element type - // will easily be ignored, so here - // I warn the user again for each type. - for(int it=0; itInit(INPUT, GlobalC::ucell); + //------------------------------------------------------------ + //------------------------------------init Basis_lcao---------------------------- // * reading the localized orbitals/projectors // * construct the interpolation tables. ORB_control orb_con( @@ -90,7 +69,7 @@ void Run_lcao::lcao_line(ModuleEnSover::En_Solver *p_ensolver) GlobalV::NB2D, GlobalV::DCOLOR, GlobalV::DRANK, GlobalV::MY_RANK, GlobalV::CALCULATION, GlobalV::KS_SOLVER); - + orb_con.read_orb_first( GlobalV::ofs_running, GlobalC::ORB, @@ -126,68 +105,9 @@ void Run_lcao::lcao_line(ModuleEnSover::En_Solver *p_ensolver) #endif orb_con.setup_2d_division(GlobalV::ofs_running, GlobalV::ofs_warning); -//-------------------------------------- -// cell relaxation should begin here -//-------------------------------------- - - // Initalize the plane wave basis set - GlobalC::pw.gen_pw(GlobalV::ofs_running, GlobalC::ucell, GlobalC::kv); - ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running,"INIT PLANEWAVE"); - std::cout << " UNIFORM GRID DIM : " << GlobalC::pw.nx <<" * " << GlobalC::pw.ny <<" * "<< GlobalC::pw.nz << std::endl; - std::cout << " UNIFORM GRID DIM(BIG): " << GlobalC::pw.nbx <<" * " << GlobalC::pw.nby <<" * "<< GlobalC::pw.nbz << std::endl; - - // the symmetry of a variety of systems. - if(GlobalV::CALCULATION == "test") - { - Cal_Test::test_memory(); - ModuleBase::QUIT(); - } - - // initialize the real-space uniform grid for FFT and parallel - // distribution of plane waves - GlobalC::Pgrid.init(GlobalC::pw.ncx, GlobalC::pw.ncy, GlobalC::pw.ncz, GlobalC::pw.nczp, - GlobalC::pw.nrxx, GlobalC::pw.nbz, GlobalC::pw.bz); // mohan add 2010-07-22, update 2011-05-04 - // Calculate Structure factor - GlobalC::pw.setup_structure_factor(); - - // Inititlize the charge density. - GlobalC::CHR.allocate(GlobalV::NSPIN, GlobalC::pw.nrxx, GlobalC::pw.ngmc); - ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running,"INIT CHARGE"); - - // Initializee the potential. - GlobalC::pot.allocate(GlobalC::pw.nrxx); - ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running,"INIT POTENTIAL"); - - - // Peize Lin add 2018-11-30 -#ifdef __MPI - if(GlobalV::CALCULATION=="nscf") - { - switch(GlobalC::exx_global.info.hybrid_type) - { - case Exx_Global::Hybrid_Type::HF: - case Exx_Global::Hybrid_Type::PBE0: - case Exx_Global::Hybrid_Type::HSE: - GlobalC::exx_global.info.set_xcfunc(GlobalC::xcf); - break; - } - } -#endif - -#ifdef __DEEPKS - //wenfei 2021-12-19 - //if we are performing DeePKS calculations, we need to load a model - if (GlobalV::out_descriptor) - { - if (GlobalV::deepks_scf) - { - // load the DeePKS model from deep neural network - GlobalC::ld.load_model(INPUT.model_file); - } - } -#endif + //------------------------------------\init Basis_lcao---------------------------- - if(GlobalV::CALCULATION=="md") + if (GlobalV::CALCULATION == "md") { Run_MD_LCAO run_md_lcao(orb_con.ParaV); run_md_lcao.opt_cell(orb_con, p_ensolver); @@ -196,7 +116,7 @@ void Run_lcao::lcao_line(ModuleEnSover::En_Solver *p_ensolver) { LOOP_cell lc(orb_con.ParaV); //keep wfc_gamma or wfc_k remaining - lc.opt_cell(orb_con); + lc.opt_cell(orb_con, p_ensolver); } ModuleBase::timer::tick("Run_lcao","lcao_line"); diff --git a/source/src_lcao/LOOP_cell.cpp b/source/src_lcao/LOOP_cell.cpp index 50fd9ab4d42..0474e59e0ed 100644 --- a/source/src_lcao/LOOP_cell.cpp +++ b/source/src_lcao/LOOP_cell.cpp @@ -15,7 +15,7 @@ LOOP_cell::LOOP_cell(Parallel_Orbitals &pv) } LOOP_cell::~LOOP_cell() {} -void LOOP_cell::opt_cell(ORB_control &orb_con) +void LOOP_cell::opt_cell(ORB_control &orb_con, ModuleEnSover::En_Solver *p_ensolver) { ModuleBase::TITLE("LOOP_cell","opt_cell"); @@ -73,7 +73,7 @@ void LOOP_cell::opt_cell(ORB_control &orb_con) if(INPUT.dft_plus_dmft) GlobalC::dmft.init(INPUT, GlobalC::ucell); LOOP_ions ions(this->LM); - ions.opt_ions(); + ions.opt_ions(p_ensolver); // mohan update 2021-02-10 orb_con.clear_after_ions(GlobalC::UOT, GlobalC::ORB, GlobalV::out_descriptor, GlobalC::ucell.infoNL.nproj); diff --git a/source/src_lcao/LOOP_cell.h b/source/src_lcao/LOOP_cell.h index a4f8dedd057..b5bc6ba153a 100644 --- a/source/src_lcao/LOOP_cell.h +++ b/source/src_lcao/LOOP_cell.h @@ -5,6 +5,7 @@ #include "module_base/complexmatrix.h" #include "module_orbital/ORB_control.h" #include "src_lcao/LCAO_matrix.h" +#include "module_ensolver/en_solver.h" class LOOP_cell { @@ -13,7 +14,7 @@ class LOOP_cell LOOP_cell(Parallel_Orbitals &pv); ~LOOP_cell(); - void opt_cell(ORB_control &orb_con); + void opt_cell(ORB_control &orb_con, ModuleEnSover::En_Solver *p_ensolver); private: LCAO_Matrix LM; diff --git a/source/src_lcao/LOOP_ions.cpp b/source/src_lcao/LOOP_ions.cpp index 64d85fbde0a..02efbcce737 100644 --- a/source/src_lcao/LOOP_ions.cpp +++ b/source/src_lcao/LOOP_ions.cpp @@ -35,7 +35,7 @@ LOOP_ions::LOOP_ions(LCAO_Matrix &lm) LOOP_ions::~LOOP_ions() {} -void LOOP_ions::opt_ions() +void LOOP_ions::opt_ions(ModuleEnSover::En_Solver *p_ensolver) { ModuleBase::TITLE("LOOP_ions","opt_ions"); ModuleBase::timer::tick("LOOP_ions","opt_ions"); @@ -151,9 +151,8 @@ void LOOP_ions::opt_ions() Record_adj RA; // solve electronic structures in terms of LCAO - // mohan add 2021-02-09 - LOE.solve_elec_stru(this->istep, RA, this->LOC, this->LOWF, this->UHM); - + // mohan add 2021-02-09 + p_ensolver->Run(this->istep, RA, this->LOC, this->LOWF, this->UHM); time_t eend = time(NULL); diff --git a/source/src_lcao/LOOP_ions.h b/source/src_lcao/LOOP_ions.h index d30b4c5d0af..a9ec64622e1 100644 --- a/source/src_lcao/LOOP_ions.h +++ b/source/src_lcao/LOOP_ions.h @@ -8,6 +8,7 @@ #include "src_lcao/local_orbital_wfc.h" #include "module_orbital/ORB_control.h" #include "src_lcao/LCAO_hamilt.h" +#include "module_ensolver/en_solver.h" #include @@ -24,7 +25,7 @@ class LOOP_ions Local_Orbital_Charge LOC; LCAO_Hamilt UHM; - void opt_ions(); //output for dos + void opt_ions(ModuleEnSover::En_Solver *p_ensolver); //output for dos void output_HS_R( const std::string &SR_filename="data-SR-sparse_SPIN0.csr", const std::string &HR_filename_up="data-HR-sparse_SPIN0.csr", From b787150577a099f015e5378fd9203908105a657a Mon Sep 17 00:00:00 2001 From: maki49 <1579492865@qq.com> Date: Mon, 28 Feb 2022 17:02:15 +0800 Subject: [PATCH 51/52] fix a pointer-bug in DFTU --- source/src_lcao/dftu.h | 4 ---- source/src_lcao/dftu_relax.h | 2 +- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/source/src_lcao/dftu.h b/source/src_lcao/dftu.h index e97e4db5849..29fc02e06b7 100644 --- a/source/src_lcao/dftu.h +++ b/source/src_lcao/dftu.h @@ -62,10 +62,6 @@ class DFTU : public DFTU_RELAX double EU; int iter_dftu; - -private: - LCAO_Matrix* LM; - }; } namespace GlobalC diff --git a/source/src_lcao/dftu_relax.h b/source/src_lcao/dftu_relax.h index 7d49e4d3b14..d0672304af8 100644 --- a/source/src_lcao/dftu_relax.h +++ b/source/src_lcao/dftu_relax.h @@ -64,7 +64,7 @@ class DFTU_RELAX : public DFTU_Yukawa //locale_save: the input local occupation number matrix of correlated electrons in the current electronic step std::vector>>> locale; // locale[iat][l][n][spin](m1,m2) std::vector>>> locale_save; // locale_save[iat][l][n][spin](m1,m2) -private: +protected: LCAO_Matrix* LM; }; } From 82b349ec1954e0f90abbfa3f0dc3e88fb6d34990 Mon Sep 17 00:00:00 2001 From: maki49 <1579492865@qq.com> Date: Mon, 28 Feb 2022 21:42:15 +0800 Subject: [PATCH 52/52] rearange init-basis in Run_lcao --- source/module_orbital/ORB_control.cpp | 6 +- source/module_orbital/ORB_control.h | 6 ++ source/run_lcao.cpp | 92 +++++++++++++++------------ source/run_lcao.h | 2 + 4 files changed, 64 insertions(+), 42 deletions(-) diff --git a/source/module_orbital/ORB_control.cpp b/source/module_orbital/ORB_control.cpp index 933c15d04dc..c77e3f3f4bd 100644 --- a/source/module_orbital/ORB_control.cpp +++ b/source/module_orbital/ORB_control.cpp @@ -30,12 +30,14 @@ ORB_control::ORB_control( drank(drank_in), myrank(myrank_in), calculation(calculation_in), - ks_solver(ks_solver_in) + ks_solver(ks_solver_in), + setup_2d(true) { this->ParaV.nspin = nspin_in; } -ORB_control::ORB_control() +ORB_control::ORB_control() : + setup_2d(false) {} ORB_control::~ORB_control() {} diff --git a/source/module_orbital/ORB_control.h b/source/module_orbital/ORB_control.h index de72021fee3..2595cab3926 100644 --- a/source/module_orbital/ORB_control.h +++ b/source/module_orbital/ORB_control.h @@ -1,6 +1,8 @@ #ifndef ORB_CONTROL_H #define ORB_CONTROL_H +#include "input.h" +#include "module_cell/unitcell_pseudo.h" #include "parallel_orbitals.h" #include "ORB_gen_tables.h" #include "ORB_read.h" @@ -30,6 +32,8 @@ class ORB_control ~ORB_control(); + void Init(Input &inp, UnitCell_pseudo &ucell); + //first step: read orbital file void read_orb_first( std::ofstream &ofs_in, @@ -73,6 +77,8 @@ class ORB_control Parallel_Orbitals ParaV; + bool setup_2d = false; + private: const bool gamma_only = 1; const int nlocal = 0; diff --git a/source/run_lcao.cpp b/source/run_lcao.cpp index 4d005dfac14..12f9e6b5df1 100644 --- a/source/run_lcao.cpp +++ b/source/run_lcao.cpp @@ -17,7 +17,8 @@ void Run_lcao::lcao_line(ModuleEnSover::En_Solver *p_ensolver) { ModuleBase::TITLE("Run_lcao","lcao_line"); ModuleBase::timer::tick("Run_lcao", "lcao_line"); - //------------------------------------init Cell---------------------------- + + //-----------------------init Cell-------------------------- // Setup the unitcell. // improvement: a) separating the first reading of the atom_card and subsequent // cell relaxation. b) put GlobalV::NLOCAL and GlobalV::NBANDS as input parameters @@ -51,7 +52,7 @@ void Run_lcao::lcao_line(ModuleEnSover::En_Solver *p_ensolver) Cal_Test::test_memory(); ModuleBase::QUIT(); } - //------------------------------------\init Cell---------------------------- + //-----------------------init Cell-------------------------- //------------------------------------------------------------ @@ -59,7 +60,8 @@ void Run_lcao::lcao_line(ModuleEnSover::En_Solver *p_ensolver) p_ensolver->Init(INPUT, GlobalC::ucell); //------------------------------------------------------------ - //------------------------------------init Basis_lcao---------------------------- + //------------------init Basis_lcao---------------------- + // Init Basis should be put outside of Ensolver. // * reading the localized orbitals/projectors // * construct the interpolation tables. ORB_control orb_con( @@ -69,44 +71,12 @@ void Run_lcao::lcao_line(ModuleEnSover::En_Solver *p_ensolver) GlobalV::NB2D, GlobalV::DCOLOR, GlobalV::DRANK, GlobalV::MY_RANK, GlobalV::CALCULATION, GlobalV::KS_SOLVER); + + Init_Basis_lcao(orb_con, INPUT, GlobalC::ucell); + //------------------init Basis_lcao---------------------- - orb_con.read_orb_first( - GlobalV::ofs_running, - GlobalC::ORB, - GlobalC::ucell.ntype, - GlobalC::ucell.lmax, - INPUT.lcao_ecut, - INPUT.lcao_dk, - INPUT.lcao_dr, - INPUT.lcao_rmax, - GlobalV::out_descriptor, - INPUT.out_r_matrix, - GlobalV::FORCE, - GlobalV::MY_RANK); - - GlobalC::ucell.infoNL.setupNonlocal( - GlobalC::ucell.ntype, - GlobalC::ucell.atoms, - GlobalV::ofs_running, - GlobalC::ORB - ); - -#ifdef __MPI - orb_con.set_orb_tables( - GlobalV::ofs_running, - GlobalC::UOT, - GlobalC::ORB, - GlobalC::ucell.lat0, - GlobalV::out_descriptor, - Exx_Abfs::Lmax, - GlobalC::ucell.infoNL.nprojmax, - GlobalC::ucell.infoNL.nproj, - GlobalC::ucell.infoNL.Beta); -#endif - - orb_con.setup_2d_division(GlobalV::ofs_running, GlobalV::ofs_warning); - //------------------------------------\init Basis_lcao---------------------------- + //---------------------------MD/Relax------------------ if (GlobalV::CALCULATION == "md") { Run_MD_LCAO run_md_lcao(orb_con.ParaV); @@ -117,8 +87,50 @@ void Run_lcao::lcao_line(ModuleEnSover::En_Solver *p_ensolver) LOOP_cell lc(orb_con.ParaV); //keep wfc_gamma or wfc_k remaining lc.opt_cell(orb_con, p_ensolver); - } + } + //---------------------------MD/Relax------------------ ModuleBase::timer::tick("Run_lcao","lcao_line"); return; } + +void Run_lcao::Init_Basis_lcao(ORB_control& orb_con, Input& inp, UnitCell_pseudo& ucell) +{ + // * reading the localized orbitals/projectors + // * construct the interpolation tables. + orb_con.read_orb_first( + GlobalV::ofs_running, + GlobalC::ORB, + ucell.ntype, + ucell.lmax, + inp.lcao_ecut, + inp.lcao_dk, + inp.lcao_dr, + inp.lcao_rmax, + GlobalV::out_descriptor, + inp.out_r_matrix, + GlobalV::FORCE, + GlobalV::MY_RANK); + + ucell.infoNL.setupNonlocal( + ucell.ntype, + ucell.atoms, + GlobalV::ofs_running, + GlobalC::ORB); + +#ifdef __MPI + orb_con.set_orb_tables( + GlobalV::ofs_running, + GlobalC::UOT, + GlobalC::ORB, + ucell.lat0, + GlobalV::out_descriptor, + Exx_Abfs::Lmax, + ucell.infoNL.nprojmax, + ucell.infoNL.nproj, + ucell.infoNL.Beta); +#endif + + if (orb_con.setup_2d) + orb_con.setup_2d_division(GlobalV::ofs_running, GlobalV::ofs_warning); +} \ No newline at end of file diff --git a/source/run_lcao.h b/source/run_lcao.h index 441bb5e68a8..bcc4c0d91ce 100644 --- a/source/run_lcao.h +++ b/source/run_lcao.h @@ -22,6 +22,8 @@ class Run_lcao // perform Linear Combination of Atomic Orbitals (LCAO) calculations static void lcao_line(ModuleEnSover::En_Solver *p_ensolver); +private: + static void Init_Basis_lcao(ORB_control& orb_con, Input& inp, UnitCell_pseudo& ucell); }; #endif