diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 41318c963ec..e7cdd42dc31 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -8,7 +8,7 @@ on: - ABACUS_2.2.0_beta - deepks - planewave - - TDDFT + - pw_refactor jobs: test: diff --git a/CMakeLists.txt b/CMakeLists.txt index 1d4415a82d4..8716afbc8c7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -272,6 +272,7 @@ target_link_libraries(${ABACUS_BIN_NAME} pw ri driver + en_solver -lm ) diff --git a/doc/input-main.md b/doc/input-main.md index 654ce8714de..8e87f6175b5 100644 --- a/doc/input-main.md +++ b/doc/input-main.md @@ -992,7 +992,7 @@ This part of variables are relevant when using hybrid functionals - exx_hybrid_type - *Type*: String - - *Description*: Type of hybrid functional used. Options are "hf" (pure Hartree-Fock), "pbe0"(PBE0), "hse" (Note: in order to use HSE functional, LIBXC is required). + - *Description*: Type of hybrid functional used. Options are "hf" (pure Hartree-Fock), "pbe0"(PBE0), "hse" (Note: in order to use HSE functional, LIBXC is required). Note also that HSE has been tested while PBE0 has NOT been fully tested yet, and the maxmum parallel cpus for running exx is Nx(N+1)/2, with N being the number of atoms. If set to "no", then no hybrid functional is used (i.e.,Fock exchange is not included.) @@ -1023,7 +1023,7 @@ adial integration for pseudopotentials, in Bohr. - exx_pca_threshold - *Type*: Real - - *Description*: To accelerate the evaluation of four-center integrals (ik|jl), the product of atomic orbitals are expanded in the basis of auxiliary basis functions (ABF): φiφj~CkijPk. The size of the ABF (i.e. number of Pk) is reduced using principal component analysis. When a large PCA threshold is used, the number of ABF will be reduced, hence the calculations becomes faster. However this comes at the cost of computational accuracy. A relatively safe choice of the value is 1d-3. + - *Description*: To accelerate the evaluation of four-center integrals (ik|jl), the product of atomic orbitals are expanded in the basis of auxiliary basis functions (ABF): φiφj~CkijPk. The size of the ABF (i.e. number of Pk) is reduced using principal component analysis. When a large PCA threshold is used, the number of ABF will be reduced, hence the calculations becomes faster. However this comes at the cost of computational accuracy. A relatively safe choice of the value is 1d-4. - *Default*: 0 [back to top](#input-file) @@ -1044,21 +1044,21 @@ adial integration for pseudopotentials, in Bohr. - exx_dm_threshold - *Type*: Real - - *Description*: The Fock exchange can be expressed as Σk,l(ik|jl)Dkl where D is the density matrix. Smaller values of the density matrix can be truncated to accelerate calculation. The larger the threshold is, the faster the calculation and the lower the accuracy. A relatively safe choice of the value is 1d-3. + - *Description*: The Fock exchange can be expressed as Σk,l(ik|jl)Dkl where D is the density matrix. Smaller values of the density matrix can be truncated to accelerate calculation. The larger the threshold is, the faster the calculation and the lower the accuracy. A relatively safe choice of the value is 1d-4. - *Default*: 0 [back to top](#input-file) - exx_schwarz_threshold - *Type*: Real - - *Description*: In practice the four-center integrals are sparse, and using Cauchy-Schwartz inequality, we can find an upper bound of each integral before carrying out explicit evaluations. Those that are smaller than exx_schwarz_threshold will be truncated. The larger the threshold is, the faster the calculation and the lower the accuracy. A relatively safe choice of the value is 1d-4. + - *Description*: In practice the four-center integrals are sparse, and using Cauchy-Schwartz inequality, we can find an upper bound of each integral before carrying out explicit evaluations. Those that are smaller than exx_schwarz_threshold will be truncated. The larger the threshold is, the faster the calculation and the lower the accuracy. A relatively safe choice of the value is 1d-5. - *Default*: 0 [back to top](#input-file) - exx_cauchy_threshold - *Type*: Real - - *Description*: In practice the Fock exchange matrix is sparse, and using Cauchy-Schwartz inequality, we can find an upper bound of each matrix element before carrying out explicit evaluations. Those that are smaller than exx_cauchy_threshold will be truncated. The larger the threshold is, the faster the calculation and the lower the accuracy. A relatively safe choice of the value is 1d-6. + - *Description*: In practice the Fock exchange matrix is sparse, and using Cauchy-Schwartz inequality, we can find an upper bound of each matrix element before carrying out explicit evaluations. Those that are smaller than exx_cauchy_threshold will be truncated. The larger the threshold is, the faster the calculation and the lower the accuracy. A relatively safe choice of the value is 1d-7. - *Default*: 0 [back to top](#input-file) diff --git a/source/CMakeLists.txt b/source/CMakeLists.txt index e65955d5989..fb72efffb62 100644 --- a/source/CMakeLists.txt +++ b/source/CMakeLists.txt @@ -5,6 +5,7 @@ add_subdirectory(module_neighbor) add_subdirectory(module_orbital) add_subdirectory(module_md) add_subdirectory(module_deepks) +add_subdirectory(module_ensolver) add_subdirectory(src_io) add_subdirectory(src_ions) add_subdirectory(src_lcao) diff --git a/source/Makefile b/source/Makefile index b3bcc1d2218..eab82838029 100644 --- a/source/Makefile +++ b/source/Makefile @@ -25,6 +25,8 @@ VPATH=./src_global\ :./src_ri\ :./\ +include module_ensolver/Makefile.ensolver + #========================== # Define HONG #========================== diff --git a/source/Makefile.Objects b/source/Makefile.Objects index 44a72c0d0e2..f486266f201 100644 --- a/source/Makefile.Objects +++ b/source/Makefile.Objects @@ -176,7 +176,6 @@ FORCE_k.o\ parallel_orbitals.o \ global_fp.o \ pdiag_double.o \ -pdiag_basic.o \ pdiag_common.o \ diag_scalapack_gvx.o \ subgrid_oper.o \ diff --git a/source/driver.cpp b/source/driver.cpp index 6128759c1d9..42412311a2d 100644 --- a/source/driver.cpp +++ b/source/driver.cpp @@ -85,22 +85,31 @@ void Driver::reading(void) void Driver::atomic_world(void) { ModuleBase::TITLE("Driver","atomic_world"); - //-------------------------------------------------- // choose basis sets: // pw: plane wave basis set // lcao_in_pw: LCAO expaned by plane wave basis set // lcao: linear combination of atomic orbitals //-------------------------------------------------- + string use_ensol; + ModuleEnSover::En_Solver *p_ensolver; if(GlobalV::BASIS_TYPE=="pw" || GlobalV::BASIS_TYPE=="lcao_in_pw") { - Run_pw::plane_wave_line(); + use_ensol = "ksdft_pw"; + //We set it temporarily + //Finally, we have ksdft_pw, ksdft_lcao, sdft_pw, ofdft, lj, eam, etc. + ModuleEnSover::init_esolver(p_ensolver, use_ensol); + Run_pw::plane_wave_line(p_ensolver); + ModuleEnSover::clean_esolver(p_ensolver); } #ifdef __LCAO else if(GlobalV::BASIS_TYPE=="lcao") - { - Run_lcao::lcao_line(); - } + { + use_ensol = "ksdft_lcao"; + ModuleEnSover::init_esolver(p_ensolver, use_ensol); + Run_lcao::lcao_line(p_ensolver); + ModuleEnSover::clean_esolver(p_ensolver); + } #endif ModuleBase::timer::finish( GlobalV::ofs_running ); diff --git a/source/module_base/intarray.h b/source/module_base/intarray.h index f8db1250753..64cacf014b7 100644 --- a/source/module_base/intarray.h +++ b/source/module_base/intarray.h @@ -5,10 +5,10 @@ #ifndef INTARRAY_H #define INTARRAY_H -#include +#include #include #include -#include +#include #ifdef _MCD_CHECK //#include "./src_parallel/mcd.h" @@ -16,63 +16,141 @@ namespace ModuleBase { - +/** + * @brief Integer array + * + */ class IntArray { -public: - int *ptr; - - // Constructors for different dimesnions - IntArray(const int d1 = 1, const int d2 = 1); - IntArray(const int d1, const int d2,const int d3); - IntArray(const int d1, const int d2,const int d3,const int d4); - IntArray(const int d1, const int d2,const int d3,const int d4,const int d5); - IntArray(const int d1, const int d2,const int d3,const int d4,const int d5,const int d6); - - ~IntArray(); - - void create(const int d1, const int d2); - void create(const int d1, const int d2, const int d3); - void create(const int d1, const int d2, const int d3, const int d4); - void create(const int d1, const int d2, const int d3, const int d4, const int d5); - void create(const int d1, const int d2, const int d3, const int d4, const int d5, const int d6); - - const IntArray &operator=(const IntArray &right); - const IntArray &operator=(const int &right); - - int &operator()(const int d1, const int d2); - int &operator()(const int d1, const int d2, const int d3); - int &operator()(const int d1, const int d2, const int d3,const int d4); - int &operator()(const int d1, const int d2, const int d3, const int d4, const int d5); - int &operator()(const int d1, const int d2, const int d3, const int d4, const int d5, const int d6); - - const int &operator()(const int d1,const int d2)const; - const int &operator()(const int d1,const int d2,const int d3)const; - const int &operator()(const int d1,const int d2,const int d3,const int d4)const; - const int &operator()(const int d1,const int d2,const int d3,const int d4, const int d5)const; - const int &operator()(const int d1,const int d2,const int d3,const int d4, const int d5, const int d6)const; - - void zero_out(void); - - int getSize() const{ return size;} - int getDim() const{ return dim;} - int getBound1() const{ return bound1;} - int getBound2() const{ return bound2;} - int getBound3() const{ return bound3;} - int getBound4() const { return bound4;} - int getBound5() const { return bound5;} - int getBound6() const { return bound6;} - - static int getArrayCount(void) - { return arrayCount;} - -private: - int size; - int dim; - int bound1, bound2, bound3, bound4, bound5, bound6; - static int arrayCount; - void freemem(); + public: + int *ptr; + + /** + * @brief Construct a new Int Array object + * + * @param d1 The first dimension size + * @param d2 The second dimension size + */ + IntArray(const int d1 = 1, const int d2 = 1); + IntArray(const int d1, const int d2, const int d3); + IntArray(const int d1, const int d2, const int d3, const int d4); + IntArray(const int d1, const int d2, const int d3, const int d4, const int d5); + IntArray(const int d1, const int d2, const int d3, const int d4, const int d5, const int d6); + + ~IntArray(); + + /** + * @brief Create integer arrays + * + * @param[in] d1 + * @param[in] d2 + */ + void create(const int d1, const int d2); + void create(const int d1, const int d2, const int d3); + void create(const int d1, const int d2, const int d3, const int d4); + void create(const int d1, const int d2, const int d3, const int d4, const int d5); + void create(const int d1, const int d2, const int d3, const int d4, const int d5, const int d6); + + /** + * @brief Equal an IntArray to another one + * + * @param right + * @return const IntArray& + */ + const IntArray &operator=(const IntArray &right); + + /** + * @brief Equal all elements of an IntArray to an + * integer + * + * @param right + * @return const IntArray& + */ + const IntArray &operator=(const int &right); + + /** + * @brief Access elements by using operator "()" + * + * @param d1 + * @param d2 + * @return int& + */ + int &operator()(const int d1, const int d2); + int &operator()(const int d1, const int d2, const int d3); + int &operator()(const int d1, const int d2, const int d3, const int d4); + int &operator()(const int d1, const int d2, const int d3, const int d4, const int d5); + int &operator()(const int d1, const int d2, const int d3, const int d4, const int d5, const int d6); + + /** + * @brief Access elements by using "()" through pointer + * without changing its elements + * + * @param d1 + * @param d2 + * @return const int& + */ + const int &operator()(const int d1, const int d2) const; + const int &operator()(const int d1, const int d2, const int d3) const; + const int &operator()(const int d1, const int d2, const int d3, const int d4) const; + const int &operator()(const int d1, const int d2, const int d3, const int d4, const int d5) const; + const int &operator()(const int d1, const int d2, const int d3, const int d4, const int d5, const int d6) const; + + /** + * @brief Set all elements of an IntArray to zero + * + */ + void zero_out(void); + + int getSize() const + { + return size; + } + int getDim() const + { + return dim; + } + int getBound1() const + { + return bound1; + } + int getBound2() const + { + return bound2; + } + int getBound3() const + { + return bound3; + } + int getBound4() const + { + return bound4; + } + int getBound5() const + { + return bound5; + } + int getBound6() const + { + return bound6; + } + + /** + * @brief Get the Array Count object + * + * @return int + */ + static int getArrayCount(void) + { + return arrayCount; + } + + private: + int size; + int dim; + int bound1, bound2, bound3, bound4, bound5, bound6; + static int arrayCount; + void freemem(); }; -} +} // namespace ModuleBase -#endif // IntArray class +#endif // IntArray class diff --git a/source/module_base/math_bspline.cpp b/source/module_base/math_bspline.cpp index 378fcd1ed5e..63d95e3f97b 100644 --- a/source/module_base/math_bspline.cpp +++ b/source/module_base/math_bspline.cpp @@ -32,18 +32,12 @@ namespace ModuleBase } } - void Bspline::cleanp() - { - delete[] bezier; - bezier = NULL; - } - double Bspline::bezier_ele(int n) { return this->bezier[n]; } - void Bspline::getbslpine(double x) + void Bspline::getbspline(double x) { bezier[0] = 1.0; for(int k = 1 ; k <= norder ; ++k) diff --git a/source/module_base/math_bspline.h b/source/module_base/math_bspline.h index 2eb72d8880e..86155bb79dc 100644 --- a/source/module_base/math_bspline.h +++ b/source/module_base/math_bspline.h @@ -4,57 +4,54 @@ namespace ModuleBase { - -// -//DESCRIPTION: -// A class to treat Cardinal B-spline interpolation. -// qianrui created 2021-09-14 -//MATH: -// Only uniform nodes are considered: xm-x[m-1]=Dx(>= 0) for control node: X={x0,x1,...,xm}; -// Any function p(x) can be written by -// p(x)=\sum_i{ci*M_ik(x)} (k->infinity), -// where M_ik is the i-th k-order Cardinal B-spline base function -// and ci is undetermined coefficient. -// M_i0 = H(x-xi)-H(x-x[i+1]), H(x): step function -// x-xi x[i+k+1]-x -// M_ik(x)= ---------*M_i(k-1)(x)+ ----------------*M_[i+1][k-1](x) ( xi <= x <= x[i+1] ) -// x[i+k]-xi x[i+k+1]-x[i+1] -// For uniform nodes: M_[i+1]k(x+Dx)=M_ik(x) -// If we define Bk[n] stores M_ik(x+n*Dx) for x in (xi,xi+Dx): -// x+n*Dx-xi xi+(k-n+1)*Dx-x -// Bk[n] = -----------*B(k-1)[n] + -----------------*B(k-1)[n-1] -// k*Dx k*Dx -//USAGE: -// ModuleBase::Bspline bp; -// bp.init(10,0.7,2); //Dx = 0.7, xi = 2 -// bp.getbslpine(0.5); //x = 0.5 -// cout<= 0) for control node: X={x0,x1,...,xm}; + * Any function p(x) can be written by + * p(x)=\sum_i{ci*M_ik(x)} (k->infinity), + * where M_ik is the i-th k-order Cardinal B-spline base function + * and ci is undetermined coefficient. + * M_i0 = H(x-xi)-H(x-x[i+1]), H(x): step function + * x-xi x[i+k+1]-x + * M_ik(x)= ---------*M_i(k-1)(x)+ ----------------*M_[i+1][k-1](x) ( xi <= x <= x[i+1] ) + * x[i+k]-xi x[i+k+1]-x[i+1] + * For uniform nodes: M_[i+1]k(x+Dx)=M_ik(x) + * If we define Bk[n] stores M_ik(x+n*Dx) for x in (xi,xi+Dx): + * x+n*Dx-xi xi+(k-n+1)*Dx-x + * Bk[n] = -----------*B(k-1)[n] + -----------------*B(k-1)[n-1] + * k*Dx k*Dx + * USAGE: + * ModuleBase::Bspline bp; + * bp.init(10,0.7,2); //Dx = 0.7, xi = 2 + * bp.getbslpine(0.5); //x = 0.5 + * cout<= 0 - double Dx; //Dx: the interval of control node - double xi; // xi: the starting point - double *bezier; //bezier[n] = Bk[n] + private: + int norder; // the order of bezier base; norder >= 0 + double Dx; // Dx: the interval of control node + double xi; // xi: the starting point + double *bezier; // bezier[n] = Bk[n] - public: - Bspline(); - ~Bspline(); + public: + Bspline(); + ~Bspline(); - //Init norder, Dx, xi - void init(int norderin, double Dxin, double xiin); + void init(int norderin, double Dxin, double xiin); - //delete[] bezier - void cleanp(); + // Get the result of i-th bezier base functions for different input x+xi+n*Dx. + // x should be in [0,Dx] + // n-th result is stored in bezier[n]; + void getbspline(double x); - //Get the result of i-th bezier base functions for different input x+xi+n*Dx. - //x should be in [0,Dx] - //n-th result is stored in bezier[n]; - void getbslpine(double x); - - //get the element of bezier - double bezier_ele(int n); + // get the element of bezier + double bezier_ele(int n); }; -} +} // namespace ModuleBase #endif diff --git a/source/module_base/math_polyint.h b/source/module_base/math_polyint.h index 0cc13f8f831..b15f50678c3 100644 --- a/source/module_base/math_polyint.h +++ b/source/module_base/math_polyint.h @@ -18,6 +18,19 @@ class PolyInt //======================================================== // Polynomial_Interpolation //======================================================== + + /** + * @brief Lagrange interpolation + * + * @param table [in] three dimension matrix, the data in 3rd dimension is used to do prediction + * @param dim1 [in] index of 1st dimension of table/y + * @param dim2 [in] index of 2nd dimension of table/y + * @param y [out] three dimension matrix to store the predicted value + * @param dim_y [in] index of 3rd dimension of y to store predicted value + * @param table_length [in] length of 3rd dimension of table + * @param table_interval [in] interval of 3rd dimension of table + * @param x [in] the position in 3rd dimension to be predicted + */ static void Polynomial_Interpolation ( const ModuleBase::realArray &table, @@ -30,6 +43,17 @@ class PolyInt const double &x ); + /** + * @brief Lagrange interpolation + * + * @param table [in] three dimension matrix, the data in 3rd dimension is used to do prediction + * @param dim1 [in] index of 1st dimension of table + * @param dim2 [in] index of 2nd dimension of table + * @param table_length [in] length of 3rd dimension of table + * @param table_interval [in] interval of 3rd dimension of table + * @param x [in] the position in 3rd dimension to be predicted + * @return double the predicted value + */ static double Polynomial_Interpolation ( const ModuleBase::realArray &table, @@ -37,10 +61,24 @@ class PolyInt const int &dim2, const int &table_length, const double &table_interval, - const double &x // input value + const double &x ); - static double Polynomial_Interpolation // pengfei Li 2018-3-23 + /** + * @brief Lagrange interpolation + * + * @param table [in] four dimension matrix, the data in 4th dimension is used to do prediction + * @param dim1 [in] index of 1st dimension of table + * @param dim2 [in] index of 2nd dimension of table + * @param dim3 [in] index of 3rd dimension of table + * @param table_length [in] length of 4th dimension of table + * @param table_interval [in] interval of 4th dimension of table + * @param x [in] the position in 4th dimension to be predicted + * @return double the predicted value + * @author pengfei Li + * @date 2018-3-23 + */ + static double Polynomial_Interpolation ( const ModuleBase::realArray &table, const int &dim1, @@ -48,23 +86,41 @@ class PolyInt const int &dim3, const int &table_length, const double &table_interval, - const double &x // input value + const double &x ); + /** + * @brief Lagrange interpolation + * + * @param table [in] the data used to do prediction + * @param table_length [in] length of table + * @param table_interval [in] interval of table + * @param x [in] the position to be predicted + * @return double the predicted value + */ static double Polynomial_Interpolation ( const double *table, const int &table_length, const double &table_interval, - const double &x // input value + const double &x ); + /** + * @brief Lagrange interpolation + * + * @param xpoint [in] array of postion + * @param ypoint [in] array of data to do prediction + * @param table_length [in] length of xpoint + * @param x [in] position to be predicted + * @return double predicted value + */ static double Polynomial_Interpolation_xy ( const double *xpoint, const double *ypoint, const int table_length, - const double &x // input value + const double &x ); }; diff --git a/source/module_base/math_sphbes.h b/source/module_base/math_sphbes.h index f152e52e659..f7257b2dc3e 100644 --- a/source/module_base/math_sphbes.h +++ b/source/module_base/math_sphbes.h @@ -14,26 +14,53 @@ class Sphbes Sphbes(); ~Sphbes(); + /** + * @brief spherical bessel + * + * @param msh [in] number of grid points + * @param r [in] radial grid + * @param q [in] k_radial + * @param l [in] angular momentum + * @param jl [out] jl spherical bessel function + */ static void Spherical_Bessel ( - const int &msh, //number of grid points - const double *r,//radial grid - const double &q, // - const int &l, //angular momentum - double *jl //jl(1:msh) = j_l(q*r(i)),spherical bessel function + const int &msh, + const double *r, + const double &q, + const int &l, + double *jl ); + /** + * @brief spherical bessel + * + * @param msh [in] number of grid points + * @param r [in] radial grid + * @param q [in] k_radial + * @param l [in] angular momentum + * @param jl [out] jl spherical bessel function + * @param sjp [out] sjp[i] is assigned to be 1.0. i < msh. + */ static void Spherical_Bessel ( - const int &msh, //number of grid points - const double *r,//radial grid - const double &q, // - const int &l, //angular momentum - double *sj, //jl(1:msh) = j_l(q*r(i)),spherical bessel function + const int &msh, + const double *r, + const double &q, + const int &l, + double *sj, double *sjp ); - + /** + * @brief return num eigenvalues of spherical bessel function + * + * @param num [in] the number of eigenvalues + * @param l [in] angular number + * @param epsilon [in] the accuracy + * @param eigenvalue [out] the calculated eigenvalues + * @param rcut [in] the cutoff the radial function + */ static void Spherical_Bessel_Roots ( const int &num, diff --git a/source/module_base/math_ylmreal.h b/source/module_base/math_ylmreal.h index 362f2dfdd82..c051ef95ca3 100644 --- a/source/module_base/math_ylmreal.h +++ b/source/module_base/math_ylmreal.h @@ -14,29 +14,54 @@ class YlmReal YlmReal(); ~YlmReal(); + /** + * @brief spherical harmonic function (real form) an array of vectors + * + * @param lmax2 [in] lmax2 = (lmax + 1)^2 ; lmax = angular quantum number + * @param ng [in] the number of vectors + * @param g [in] an array of vectors + * @param ylm [out] Ylm; column index represent vector, row index represent Y00, Y10, Y11, Y1-1, Y20,Y21,Y2-1,Y22.Y2-2,...; + */ static void Ylm_Real ( - const int lmax2, // lmax2 = (lmax+1)^2 - const int ng, // - const ModuleBase::Vector3 *g, // g_cartesian_vec(x,y,z) - matrix &ylm // output + const int lmax2, + const int ng, + const ModuleBase::Vector3 *g, + matrix &ylm ); + /** + * @brief spherical harmonic function (Herglotz generating form) of an array of vectors + * + * @param lmax2 [in] lmax2 = (lmax + 1)^2 ; lmax = angular quantum number + * @param ng [in] the number of vectors + * @param g [in] an array of vectors + * @param ylm [out] Ylm; column index represent vector, row index represent Y00, Y10, Y11, Y1-1, Y20,Y21,Y2-1,Y22.Y2-2,...; + */ static void Ylm_Real2 ( - const int lmax2, // lmax2 = (lmax+1)^2 - const int ng, // - const ModuleBase::Vector3 *g, // g_cartesian_vec(x,y,z) - matrix &ylm // output + const int lmax2, + const int ng, + const ModuleBase::Vector3 *g, + matrix &ylm ); + /** + * @brief spherical harmonic function (Herglotz generating form) of a vector + * + * @param lmax [in] maximum angular quantum number + * @param x [in] x part of the vector + * @param y [in] y part of the vector + * @param z [in] z part of the vector + * @param rly [in] Ylm, Y00, Y10, Y11, Y1-1, Y20,Y21,Y2-1,Y22.Y2-2,... + */ static void rlylm ( const int lmax, const double& x, const double& y, - const double& z, // g_cartesian_vec(x,y,z) - double* rly // output + const double& z, + double* rly ); private: diff --git a/source/module_base/mathzone.h b/source/module_base/mathzone.h index 99622ac29a1..6e5180afd8c 100644 --- a/source/module_base/mathzone.h +++ b/source/module_base/mathzone.h @@ -1,69 +1,82 @@ #ifndef MATHZONE_H #define MATHZONE_H -#include "realarray.h" -#include "matrix3.h" #include "global_function.h" -#include -#include +#include "matrix3.h" +#include "realarray.h" + #include #include +#include +#include namespace ModuleBase { +/** + * @brief atomic coordinates conversion functions + * + */ class Mathzone { - public: - + public: Mathzone(); ~Mathzone(); - template - static T Max3(const T &a,const T &b,const T &c) + public: + /** + * @brief Pointwise product of two vectors with same size + * + * @tparam Type + * @param[in] f1 + * @param[in] f2 + * @return std::vector + * @author Peize Lin (2016-08-03) + */ + template + static std::vector Pointwise_Product(const std::vector &f1, const std::vector &f2) { - if (a>=b && a>=c) return a; - else if (b>=a && b>=c) return b; - else if (c>=a && c>=b) return c; - else throw std::runtime_error(ModuleBase::GlobalFunc::TO_STRING(__FILE__)+" line "+ModuleBase::GlobalFunc::TO_STRING(__LINE__)); + assert(f1.size() == f2.size()); + std::vector f(f1.size()); + for (int ir = 0; ir != f.size(); ++ir) + f[ir] = f1[ir] * f2[ir]; + return f; } - // be careful, this can only be used for plane wave - // during parallel calculation - -public: - - - - // Peize Lin add 2016-08-03 - template< typename Type > - static std::vector Pointwise_Product( const std::vector &f1, const std::vector &f2 ) - { - assert(f1.size()==f2.size()); - std::vector f(f1.size()); - for( int ir=0; ir!=f.size(); ++ir ) - f[ir] = f1[ir] * f2[ir]; - return f; - } - -//========================================================== -// MEMBER FUNCTION : -// NAME : Direct_to_Cartesian -// use lattice vector matrix R -// change the direct std::vector (dx,dy,dz) to cartesuab vectir -// (cx,cy,cz) -// (dx,dy,dz) = (cx,cy,cz) * R -// -// NAME : Cartesian_to_Direct -// the same as above -// (cx,cy,cz) = (dx,dy,dz) * R^(-1) -//========================================================== - static inline void Direct_to_Cartesian - ( - const double &dx,const double &dy,const double &dz, - const double &R11,const double &R12,const double &R13, - const double &R21,const double &R22,const double &R23, - const double &R31,const double &R32,const double &R33, - double &cx,double &cy,double &cz) + /** + * @brief change direct coordinate (dx,dy,dz) to + * Cartesian coordinate (cx,cy,cz), (dx,dy,dz) = (cx,cy,cz) * R + * + * @param[in] dx Direct coordinats + * @param[in] dy + * @param[in] dz + * @param[in] R11 Lattice vector matrix R_ij: i_row, j_column + * @param[in] R12 + * @param[in] R13 + * @param[in] R21 + * @param[in] R22 + * @param[in] R23 + * @param[in] R31 + * @param[in] R32 + * @param[in] R33 + * @param[out] cx Cartesian coordinats + * @param[out] cy + * @param[out] cz + */ + static inline void Direct_to_Cartesian(const double &dx, + const double &dy, + const double &dz, + const double &R11, + const double &R12, + const double &R13, + const double &R21, + const double &R22, + const double &R23, + const double &R31, + const double &R32, + const double &R33, + double &cx, + double &cy, + double &cz) { static ModuleBase::Matrix3 lattice_vector; static ModuleBase::Vector3 direct_vec, cartesian_vec; @@ -88,13 +101,41 @@ class Mathzone return; } - static inline void Cartesian_to_Direct - ( - const double &cx,const double &cy,const double &cz, - const double &R11,const double &R12,const double &R13, - const double &R21,const double &R22,const double &R23, - const double &R31,const double &R32,const double &R33, - double &dx,double &dy,double &dz) + /** + * @brief Change Cartesian coordinate (cx,cy,cz) to + * direct coordinate (dx,dy,dz), (cx,cy,cz) = (dx,dy,dz) * R^(-1) + * + * @param[in] cx Cartesian coordinats + * @param[in] cy + * @param[in] cz + * @param[in] R11 Lattice vector matrix R_ij: i_row, j_column + * @param[in] R12 + * @param[in] R13 + * @param[in] R21 + * @param[in] R22 + * @param[in] R23 + * @param[in] R31 + * @param[in] R32 + * @param[in] R33 + * @param[out] dx Direct coordinats + * @param[out] dy + * @param[out] dz + */ + static inline void Cartesian_to_Direct(const double &cx, + const double &cy, + const double &cz, + const double &R11, + const double &R12, + const double &R13, + const double &R21, + const double &R22, + const double &R23, + const double &R31, + const double &R32, + const double &R33, + double &dx, + double &dy, + double &dz) { static ModuleBase::Matrix3 lattice_vector, inv_lat; lattice_vector.e11 = R11; @@ -120,45 +161,17 @@ class Mathzone return; } - - static void To_Polar_Coordinate + void To_Polar_Coordinate ( - const double &x_cartesian, - const double &y_cartesian, - const double &z_cartesian, - double &r, - double &theta, - double &phi - ); - - - // coeff1 * x1 + (1-coeff1) * x2 - // Peize Lin add 2017-08-09 - template< typename T, typename T_coeff > - static T Linear_Mixing( const T & x1, const T & x2, const T_coeff & coeff1 ) - { - return coeff1 * x1 + (1-coeff1) * x2; - } - template< typename T, typename T_coeff > - static std::vector Linear_Mixing( const std::vector & x1, const std::vector & x2, const T_coeff & coeff1 ) - { - assert(x1.size()==x2.size()); - std::vector x; - for( size_t i=0; i!=x1.size(); ++i ) - x.push_back( Linear_Mixing( x1[i], x2[i], coeff1 ) ); - return x; - } - template< typename T1, typename T2, typename T_coeff > - static std::map Linear_Mixing( const std::map & x1, const std::map & x2, const T_coeff & coeff1 ) - { - std::map x; - for( const auto & x1i : x1 ) - x.insert( make_pair( x1i.first, Linear_Mixing( x1i.second, x2.at(x1i.first), coeff1 ) ) ); - return x; - } + const double &x_cartesian, + const double &y_cartesian, + const double &z_cartesian, + double &r, + double &theta, + double &phi); }; -} +} // namespace ModuleBase #endif diff --git a/source/module_base/mathzone_add1.cpp b/source/module_base/mathzone_add1.cpp index 459f0205b11..f22667c104d 100644 --- a/source/module_base/mathzone_add1.cpp +++ b/source/module_base/mathzone_add1.cpp @@ -28,7 +28,6 @@ typedef fftw_complex FFTW_COMPLEX; namespace ModuleBase { -bool Mathzone_Add1::flag_jlx_expand_coef = false; double** Mathzone_Add1::c_ln_c = nullptr; double** Mathzone_Add1::c_ln_s = nullptr; @@ -39,605 +38,6 @@ Mathzone_Add1::~Mathzone_Add1() {} -/********************************************** - * coefficients to expand jlx using - * cos and sin, from SIESTA - * *******************************************/ -void Mathzone_Add1::expand_coef_jlx() -{ - ModuleBase::timer::tick("Mathzone_Add1","expand_coef_jlx"); - - int ir, il, in; - const int L = 20; - - //allocation - delete[] c_ln_c; - delete[] c_ln_s; - - c_ln_c = new double*[L+1]; - c_ln_s = new double*[L+1]; - - for(ir = 0; ir < L+1; ir++) - { - c_ln_c[ir] = new double[ir+1]; - c_ln_s[ir] = new double[ir+1]; - } - - //calculate initial value - c_ln_c[0][0] = 0.0; - c_ln_s[0][0] = 1.0; - - c_ln_c[1][0] = 0.0; - c_ln_c[1][1] = -1.0; - c_ln_s[1][0] = 1.0; - c_ln_s[1][1] = 0.0; - - //recursive equation - for(il = 2; il < Mathzone_Add1::sph_lmax+1; il++) - { - for(in = 0; in < il + 1; in++) - { - if(in >= 2) - { - if(in == il) - { - c_ln_c[il][in] = -c_ln_c[il-2][in-2]; - c_ln_s[il][in] = -c_ln_s[il-2][in-2]; - } - else - { - c_ln_c[il][in] = (2*il-1)*c_ln_c[il-1][in] - c_ln_c[il-2][in-2]; - c_ln_s[il][in] = (2*il-1)*c_ln_s[il-1][in] - c_ln_s[il-2][in-2]; - } - } - else - { - //in = 0 or 1, but here il = 2, so in != il - c_ln_c[il][in] = (2*il-1)*c_ln_c[il-1][in]; - c_ln_s[il][in] = (2*il-1)*c_ln_s[il-1][in]; - } - } - } - - ModuleBase::timer::tick("Mathzone_Add1","expand_coef_jlx"); - return; -} - -void Mathzone_Add1::Spherical_Bessel -( - const int &msh, //number of grid points - const double *r,//radial grid - const double &q, // - const int &l, //angular momentum - double *sj, //jl(1:msh) = j_l(q*r(i)),spherical bessel function - double *sjp -) -{ - ModuleBase::timer::tick ("Mathzone_Add1","Spherical_Bessel"); - - assert (l <= Mathzone_Add1::sph_lmax); - - //creat coefficients - if (!flag_jlx_expand_coef) - { - Mathzone_Add1::expand_coef_jlx (); - flag_jlx_expand_coef = true; - } - - //epsilon - const double eps = 1.0E-10; - - /****************************************************************** - jlx = \sum_{n=0}^{l}(c_{ln}^{s}sin(x) + c_{ln}^{c}cos(x))/x^{l-n+1} - jldx = \sum_{n=0}^{l}(c_{ln}^{s}sin(x) + c_{ln}^{c}cos(x))/x^{l-n+1} - *******************************************************************/ - for (int ir = 0; ir < msh; ir++) - { - double qr = q * r[ir]; - - //judge qr approx 0 - if (fabs(qr) < eps) - { - if (l == 0) sj[ir] = 1.0; - else sj[ir] = 0.0; - - if (l == 1) sjp[ir] = 1.0 / 3.0; - else sjp[ir] = 0.0; - } - else - { - sj[ir] = 0.0; - sjp[ir] = 0.0; - - double lqr = pow (qr, l); - - //divided fac - double xqr = 1.0; - - for (int n = 0; n <= l; n++) - { - double com1 = (c_ln_s[l][n] * sin(qr) + c_ln_c[l][n] * cos(qr)) * xqr; - double com2 = (c_ln_s[l][n] * cos(qr) - c_ln_c[l][n] * sin(qr)) * xqr * qr - + (c_ln_s[l][n] * sin(qr) + c_ln_c[l][n] * cos(qr)) * n * xqr; - - sj[ir] += com1; - sjp[ir] += (com2 - l*com1); - - xqr *= qr; - } - - sj[ir] /= lqr; - sj[ir] /= (lqr * qr); - } - } - - ModuleBase::timer::tick ("Mathzone_Add1","Spherical_Bessel"); - - return; - -} - -double Mathzone_Add1::uni_simpson -( - const double* func, - const int& mshr, - const double& dr -) -{ - ModuleBase::timer::tick("Mathzone_Add1","uni_simpson"); - - assert(mshr >= 3); - - int ir=0; - int idx=0; - int msh_left=0; - double sum = 0.0; - - //f(a) - sum += func[0]; - - //simpson 3/8 rule - if(mshr % 2 == 0) - { - //simpson 3/8 rule - sum += 9.0 / 8.0 * (func[mshr-1] + 3.0*(func[mshr-2] + func[mshr-3]) + func[mshr-4]); - msh_left = mshr - 3; - } - else - { - msh_left = mshr; - } - - //f(b) - sum += func[msh_left-1]; - - //points left - for(ir = 0; ir < msh_left / 2; ir++) - { - idx = 2*ir+1; - sum += 4.0 * func[idx]; - } - - for(ir = 1; ir < msh_left / 2; ir++) - { - idx = 2*ir; - sum += 2.0 * func[idx]; - } - - ModuleBase::timer::tick("Mathzone_Add1","uni_simpson"); - return sum * dr / 3.0; -} - -void Mathzone_Add1::uni_radfft -( - const int& n, - const int& aml, - const int& mshk, - const double* arr_k, - const int& mshr, - const double* ri, - const double& dr, - const double* phir, - double* phik -) -{ - ModuleBase::timer::tick ("Mathzone_Add1","uni_radfft"); - - //allocate memory - double* fi = new double[mshr]; - double* fi_cp = new double[mshr]; - double* jl = new double[mshr]; - - //function to be integrated: r^2 phir - for (int ir = 0; ir < mshr; ir++) - { - fi_cp[ir] = pow(ri[ir], n) * phir[ir]; - } - - //integration - for (int ik = 0; ik < mshk; ik++) - { - //calculate spherical bessel - ModuleBase::Sphbes::Spherical_Bessel(mshr, ri, arr_k[ik], aml, jl); - - //functions to be integrated - for (int ir = 0; ir < mshr; ir++) - { - fi[ir] = fi_cp[ir] * jl[ir]; - } - - phik[ik] = uni_simpson (fi, mshr, dr); - } - - //deallocate memory - delete[] fi; - delete[] fi_cp; - delete[] jl; - - ModuleBase::timer::tick("Mathzone_Add1","uni_radfft"); - - return; -} - -void Mathzone_Add1::Sph_Bes (double x, int lmax, double *sb, double *dsb) -{ - ModuleBase::timer::tick("Mathzone_Add1","Spherical_Bessel"); - - int m, n, nmax; - double j0, j1, sf, tmp, si, co, ix; // j0p, j1p, ix2 - - if (x < 0.0) - { - std::cout << "\nminus x is invalid for Spherical_Bessel" << std::endl; - exit(0); // mohan add 2021-05-06 - } - - const double xmin = 1E-10; - - /* find an appropriate nmax */ - nmax = lmax + 3*static_cast(x) + 20; - if (nmax < 100) nmax = 100; - - /* allocate tsb */ - double* tsb = new double[nmax+1]; - - /* if x is larger than xmin */ - - if (xmin < x) - { - /* initial values*/ - tsb[nmax] = 0.0; - tsb[nmax-1] = 1.0e-14; - - /* downward recurrence from nmax-2 to lmax+2 */ - - for (n = nmax-1; (lmax+2) < n; n--) - { - tsb[n-1] = (2.0*n + 1.0)/x*tsb[n] - tsb[n+1]; - - if (1.0e+250 < tsb[n-1]) - { - tmp = tsb[n-1]; - tsb[n-1] /= tmp; - tsb[n ] /= tmp; - } - } - - /* downward recurrence from lmax+1 to 0 */ - n = lmax + 3; - tmp = tsb[n-1]; - tsb[n-1] /= tmp; - tsb[n ] /= tmp; - - for (n = lmax+2; 0 < n; n--) - { - tsb[n-1] = (2.0*n + 1.0)/x*tsb[n] - tsb[n+1]; - - if (1.0e+250 < tsb[n-1]) - { - tmp = tsb[n-1]; - for (m = n-1; m <= lmax+1; m++) - { - tsb[m] /= tmp; - } - } - } - - /* normalization */ - si = sin(x); - co = cos(x); - ix = 1.0/x; - //ix2 = ix*ix; - j0 = si*ix; - j1 = si*ix*ix - co*ix; - - if (fabs(tsb[1]) < fabs(tsb[0])) sf = j0/tsb[0]; - else sf = j1/tsb[1]; - - /* tsb to sb */ -// for (n = 0; n <= lmax+1; n++) - for (n = 0; n <= lmax; n++) - { - sb[n] = tsb[n]*sf; - } - - /* derivative of sb */ - dsb[0] = co*ix - si*ix*ix; - // for (n = 1; n <= lmax; n++) - for (n = 1; n < lmax; n++) - { - dsb[n] = ( (double)n*sb[n-1] - (double)(n+1.0)*sb[n+1] )/(2.0*(double)n + 1.0); - } - - n = lmax; - dsb[n] = ( (double)n*sb[n-1] - (double)(n+1.0)*sf*tsb[n+1] )/(2.0*(double)n + 1.0); - } - - /* if x is smaller than xmin */ - else - { - /* sb */ - for (n = 0; n <= lmax; n++ ) - { - sb[n] = 0.0; - } - sb[0] = 1.0; - - /* derivative of sb */ - dsb[0] = 0.0; - dsb[1] = 1.0 / 3.0; - for (n = 2; n <= lmax; n++) - { -// dsb[n] = ( (double)n*sb[n-1] - (double)(n+1.0)*sb[n+1] )/(2.0*(double)n + 1.0); - dsb[n] = 0.0; - } - } - - /* free tsb */ - delete [] tsb; - - ModuleBase::timer::tick("Mathzone_Add1","Spherical_Bessel"); - - return; -} - -void Mathzone_Add1::Sbt_new -( - const int& polint_order, - const int& l, - const double* k, - const double& dk, - const int& mshk, - const double* r, - const double& dr, - const int& mshr, - const double* fr, - const int& rpow, - double* fk -) -{ - ModuleBase::timer::tick ("Mathzone_Add1","Sbt_new"); - - //check parameter - assert (mshr >= 1); - assert (l >= 0); - assert (rpow >= 0 && rpow <=2); - - //step 0 - //l is odd or even - bool parity_flag; - if (l % 2 == 0) - { - parity_flag = true; - } - else - { - parity_flag = false; - } - - ModuleBase::GlobalFunc::ZEROS (fk, mshk); - - if (polint_order != 3) - { - std::cout << "\nhigh order interpolation is not available!" << std::endl; - exit(0); // mohan add 2021-05-06 - //ModuleBase::QUIT(); - } - - /********************************** - function multiplied by power of r - for different polint_order - **********************************/ - double* fr2; - double* fr3; - - //polint_order == 1 - fr2 = new double[mshr]; - if (rpow == 0) for (int ir = 0; ir < mshr; ir++) fr2[ir] = fr[ir]*r[ir]*r[ir]; - else if (rpow == 1) for (int ir = 0; ir < mshr; ir++) fr2[ir] = fr[ir]*r[ir]; - else if (rpow == 2) for (int ir = 0; ir < mshr; ir++) fr2[ir] = fr[ir]; - - fr3 = new double[mshr]; - for (int ir = 0; ir < mshr; ir++) fr3[ir] = fr2[ir] * r[ir]; - -// const int polint_order = 3; - int nu_pol_coef = (polint_order+1)*(mshk-1); - double* polint_coef = new double[nu_pol_coef]; - - //step 1 - //start calc - if (parity_flag) - { - //even - const int n = l/2; - - //coef for interpolation - int ct = 0; - - double ft_save = fourier_cosine_transform (fr2, r, mshr, dr, k[0]); - double dft_save = -fourier_sine_transform (fr3, r, mshr, dr, k[0]); - - for (int ik = 0; ik < mshk-1; ik++) - { - double ft0 = ft_save; - double dft0 = dft_save; - - double ft1 = fourier_cosine_transform (fr2, r, mshr, dr, k[ik+1]); - double dft1 = -fourier_sine_transform (fr3, r, mshr, dr, k[ik+1]); - - //double d2k = dk*dk; - //double d3k = d2k*dk; - - double c0 = ft0; - double c1 = dft0; - double c2 = 3.0*(ft1-ft0)/dk/dk-(dft1+2.0*dft0)/dk; - double c3 = (-2.0*(ft1-ft0)/dk+(dft1+dft0))/dk/dk; - - double k2 = k[ik]*k[ik]; - double k3 = k2*k[ik]; - - polint_coef[ct] = c0-c1*k[ik]+c2*k2-c3*k3; - polint_coef[ct+1] = c1-2.0*c2*k[ik]+3.0*c3*k2; - polint_coef[ct+2] = c2-3.0*k[ik]*c3; - polint_coef[ct+3] = c3; - - //test - // double x = (k[ik]+k[ik+1])/2; - // double x = k[ik]; -// double tmp = polint_coef[ct]+x*polint_coef[ct+1]+x*x*polint_coef[ct+2]+x*x*x*polint_coef[ct+3]; - // double tmp_ana = fourier_cosine_transform (fr2, r, mshr, dr, x); - // std::cout << "\ninterp = " << tmp << " ana = " << tmp_ana << " diff = " << log(fabs(tmp-tmp_ana))/log(10); - - //update - ct += (polint_order+1); - ft_save = ft1; - dft_save = dft1; - } - - //store coefficients for calculation - double* coef = new double[n+1]; - double fac = dualfac (l-1) / dualfac (l); - - for (int j = 0; j < n; j++) - { - coef[j] = fac; - - //update - int twoj = 2*j; - fac *= -static_cast((l+twoj+1)*(l-twoj))/(twoj+2)/(twoj+1); - } - coef[n] = pow (-1.0, n) * dualfac (2*l-1) / factorial (l); - - //start calc - //special case k = 0; - if (n ==0) fk[0] = uni_simpson (fr2, mshr, dr); - else fk[0] = 0.0; - - //k > 0 - for (int j =0; j <=n; j++) - { - //initialize Snm - double Snm = 0.0; - for (int ik = 1; ik < mshk; ik++) - { - Snm += pol_seg_int (polint_order, polint_coef, 2*j, k, ik); - double k2j = pow (k[ik], 2*j+1); - - fk[ik] += coef[j] * Snm / k2j; - } - } - //free - delete [] coef; - } - else - { - //odd - const int n = (l-1)/2; - - //coef for interpolation - int ct = 0; - double ft_save, dft_save; - ft_save = fourier_sine_transform (fr2, r, mshr, dr, k[0]); - dft_save = fourier_cosine_transform (fr3, r, mshr, dr, k[0]); - - for (int ik = 0; ik < mshk-1; ik++) - { - double ft0 = ft_save; - double dft0 = dft_save; - - double ft1 = fourier_sine_transform (fr2, r, mshr, dr, k[ik+1]); - double dft1 = fourier_cosine_transform (fr3, r, mshr, dr, k[ik+1]); - - double c0, c1, c2, c3; - c0 = ft0; - c1 = dft0; - c2 = 3.0*(ft1-ft0)/dk/dk-(dft1+2.0*dft0)/dk; - c3 = (-2.0*(ft1-ft0)/dk+(dft1+dft0))/dk/dk; - - double k2 = k[ik]*k[ik]; - double k3 = k2*k[ik]; - - polint_coef[ct] = c0-c1*k[ik]+c2*k2-c3*k3; - polint_coef[ct+1] = c1-2.0*c2*k[ik]+3.0*c3*k2; - polint_coef[ct+2] = c2-3.0*k[ik]*c3; - polint_coef[ct+3] = c3; - -// double x = (k[ik]+k[ik+1])/2; -// double x = k[ik]; -// double tmp = polint_coef[ct]+x*polint_coef[ct+1]+x*x*polint_coef[ct+2]+x*x*x*polint_coef[ct+3]; -// std::cout << "\ninterp = " << tmp << " ana = " << fourier_sine_transform (fr2, r, mshr, dr, x); - //update - ct += (polint_order+1); - ft_save = ft1; - dft_save = dft1; - } - - //store coefficients for calculation - double* coef = new double[n+1]; - double fac = dualfac (l) / dualfac (l-1); - - for (int j = 0; j < n; j++) - { - coef[j] = fac; - - //update - int twoj = 2*j; - fac *= -static_cast((l+twoj+2)*(l-twoj-1))/(twoj+3)/(twoj+2); - - //test -// std::cout << "\ncoef[j] = " << coef[j] << std::endl; - } - coef[n] = pow (-1.0, n) * dualfac (2*l-1) / factorial (l); - - //start calc - //special case k =0 ; - fk[0] = 0.0; - - //k > 0 - for (int j = 0; j <= n; j++) - { - double Snm = 0.0; - for (int ik = 1; ik < mshk; ik++) - { - Snm += pol_seg_int (polint_order, polint_coef, 2*j+1, k, ik); - double k2j = pow(k[ik], 2*j+2); - - fk[ik] += coef[j] * Snm / k2j; - } - } - - //free - delete [] coef; - } - - delete [] fr2; - delete [] fr3; - delete [] polint_coef; - - ModuleBase::timer::tick ("Mathzone_Add1","Sbt_new"); - return; -} - double Mathzone_Add1::factorial (const int& l) { if (l == 0 || l == 1) return 1.0; @@ -650,248 +50,6 @@ double Mathzone_Add1::dualfac (const int& l) else return l * dualfac (l-2); } -double Mathzone_Add1::pol_seg_int -( - const int& polint_order, - const double* coef, - const int& n, - const double* k, - const int& ik -) -{ - double val = 0.0; - double kmf = pow (k[ik], n+1); - double kmb = pow (k[ik-1], n+1); - - int cstart = (polint_order+1)*(ik-1); - for (int i = 0; i <= polint_order; i++) - { - val += coef[cstart+i]*(kmf-kmb)/(n+i+1); - /* - if (ik == 110) - { - std::cout << "i = " << i << " coef = " << coef[cstart+i] << " df = " << kmf-kmb << std::endl; - } - */ - //update - kmf *= k[ik]; - kmb *= k[ik-1]; - } - - /* - if (ik == 110) - { - std::cout << "val = " << val << std::endl; - ModuleBase::QUIT (); - } - */ - return val; -} - -double Mathzone_Add1::fourier_sine_transform -( - const double* func, - const double* r, - const int& mshr, - const double& dr, - const double& k -) -{ - ModuleBase::timer::tick ("Mathzone_Add1","Fsin"); - double val = 0.0; - double* sinf = new double[mshr]; - for (int ir = 0; ir < mshr; ir++) - { - sinf[ir] = func[ir] * sin(k*r[ir]); - } - val = uni_simpson (sinf, mshr, dr); - delete[] sinf; - ModuleBase::timer::tick ("Mathzone_Add1","Fsin"); - return val; -} - -double Mathzone_Add1::fourier_cosine_transform -( - const double* func, - const double* r, - const int& mshr, - const double& dr, - const double& k -) -{ - ModuleBase::timer::tick ("Mathzone_Add1","Fcos"); - double val = 0.0; - double* cosf = new double[mshr]; - for (int ir = 0; ir < mshr; ir++) - { - cosf[ir] = func[ir] * cos(k*r[ir]); - } - - val = uni_simpson (cosf, mshr, dr); - delete[] cosf; - ModuleBase::timer::tick ("Mathzone_Add1","Fcos"); - return val; -} - -void Mathzone_Add1::test () -{ - int polint_order =3; - int dim = 2048; - int ci = 1; - int l = 0; - double rmax = 20; - double dr = rmax/dim; - - double* rad = new double[dim]; - double* func = new double[dim]; - double* fk = new double[dim]; - for (int ir = 0; ir < dim; ir++) - { - rad[ir] = ir * dr; - func[ir] = pow(rad[ir], l) * exp(-ci*rad[ir]*rad[ir]); - fk[ir] = 0.0; - } - - Sbt_new (polint_order, l, rad, dr, dim, rad, dr, dim, func, 0, fk); - - for (int ik = 0; ik < dim; ik++) - { - double diff = fk[ik]- sqrt(ModuleBase::PI/4/ci)/pow(2.0*ci, l+1)* std::pow(rad[ik], l) * exp(-rad[ik]*rad[ik]/4/ci); - std::cout << rad[ik] << " " << fk[ik] << " " << sqrt(ModuleBase::PI/4/ci)/pow(2.0*ci, l+1)*pow(rad[ik], l)*exp(-rad[ik]*rad[ik]/4/ci) - << " " << std::log(fabs(diff))/std::log(10.0) << std::endl; - } - - delete[] rad; - delete[] func; - delete[] fk; - return; -} - -void Mathzone_Add1::test2 () -{ - int polint_order =3; - int N = 200; - int ci = 1; - int l = 0; - double rmax = 20; - double dr = rmax/(N-1); - - double dk = ModuleBase::PI / rmax /2; -// double kmax = PI / dr; -// double dk = dr; - - double* rad = new double[N]; - double* kad = new double[N]; - double* func = new double[N]; - double* fk = new double[N]; - double* fr = new double[N]; - for (int ir = 0; ir < N; ir++) - { - rad[ir] = ir * dr; - kad[ir] = ir * dk; - func[ir] = pow(rad[ir], l) * exp(-ci*rad[ir]*rad[ir]); - fk[ir] = 0.0; - fr[ir] = 0.0; - } - - Sbt_new (polint_order, l, kad, dk, N, rad, dr, N, func, 0, fk); - -/* - for (int ik = 0; ik < N; ik++) - { - double diff = fk[ik]- sqrt(PI/4/ci)/pow(2*ci, l+1)*pow(kad[ik], l)*exp(-kad[ik]*kad[ik]/4/ci); - std::cout << kad[ik] << " " << fk[ik] << " " << sqrt(PI/4/ci)/pow(2*ci, l+1)*pow(kad[ik], l)*exp(-kad[ik]*kad[ik]/4/ci) - << " " << log(fabs(diff))/log(10) << std::endl; - } - ModuleBase::QUIT (); -*/ - Sbt_new (polint_order, l, rad, dr, N, kad, dk, N, fk, 0, fr); - - for (int ir = 0; ir < N; ir++) - { - std::cout << ir*dr << " " << func[ir] << " " << fr[ir] *2.0 / ModuleBase::PI << " " << std::log(fabs(fr[ir]*2.0/ModuleBase::PI-func[ir]))/std::log(10.0) << std::endl; - } - - - delete[] rad; - delete[] func; - delete[] fk; - delete[] fr; - delete[] kad; - return; -} - -double Mathzone_Add1::Polynomial_Interpolation -( - const double* xa, - const double* ya, - const int& n, - const double& x -) -{ - ModuleBase::timer::tick("Mathzone_Add1","Polynomial_Interpolation"); - - int i, m, ns; - double den, dif, dift, ho, hp, w, rs, drs; - - //zero offset - const double* Cxa = xa - 1; - const double* Cya = ya - 1; - - double* cn = new double[n+1]; - double* dn = new double[n+1]; - - ns = 1; - dif = fabs(x - Cxa[1]); - - for(i = 1; i <= n; i++) - { - dift = fabs(x - Cxa[i]); - if(dift < dif) - { - ns = i; - dif = dift; - } - cn[i] = Cya[i]; - dn[i] = Cya[i]; - } - - rs = Cya[ns--]; - - for(m = 1; m < n; m++) - { - for(i = 1; i <= n-m; i++) - { - ho = Cxa[i] - x; - hp = Cxa[i+m] - x; - w = cn[i+1] - dn[i]; - - den = ho - hp; - if(den == 0.0) - { - std::cout << "Two Xs are equal" << std::endl; - // ModuleBase::WARNING_QUIT("Mathzone_Add1::Polynomial_Interpolation","Two Xs are equal"); - exit(0); // mohan update 2021-05-06 - } - den = w / den; - - dn[i] = hp * den; - cn[i] = ho * den; - } - if(2 * ns < n-m) drs = cn[ns+1]; - else drs = dn[ns--]; - - rs += drs; - } - - delete[] cn; - delete[] dn; - - return rs; - ModuleBase::timer::tick("Mathzone_Add1","Polynomial_Interpolation"); - -} - void Mathzone_Add1::SplineD2 // modified by pengfei 13-8-8 add second derivative as a condition ( @@ -1024,124 +182,7 @@ void Mathzone_Add1::Cubic_Spline_Interpolation ModuleBase::timer::tick("Mathzone_Add1","Cubic_Spline_Interpolation"); } -// Interpolation for Numerical Orbitals -double Mathzone_Add1::RadialF -( - const double* rad, - const double* rad_f, - const int& msh, - const int& l, - const double& R -) -{ - ModuleBase::timer::tick("Mathzone_Add1","RadialF"); - - int mp_min, mp_max, m; - double h1, h2, h3, f1, f2, f3, f4; - double g1, g2, x1, x2, y1, y2, f; - double c, result; - - mp_min = 0; - mp_max = msh - 1; - - //assume psir behaves like r**l - if (R < rad[0]) - { - if (l == 0) - { - f = rad_f[0]; - } - else - { - c = rad_f[0] / pow(rad[0], l); - f = pow(R, l) * c; - } - } - else if (rad[mp_max] < R) - { - f = 0.0; - } - else - { - do - { - m = (mp_min + mp_max)/2; - if (rad[m] < R) mp_min = m; - else mp_max = m; - } - while((mp_max-mp_min)!=1); - m = mp_max; - - if (m < 2) - { - m = 2; - } - else if (msh <= m) - { - m = msh - 2; - } - - /**************************************************** - Spline like interpolation - ****************************************************/ - - if (m == 1) - { - h2 = rad[m] - rad[m-1]; - h3 = rad[m+1] - rad[m]; - - f2 = rad_f[m-1]; - f3 = rad_f[m]; - f4 = rad_f[m+1]; - - h1 = -(h2+h3); - f1 = f4; - } - else if (m == (msh-1)) - { - h1 = rad[m-1] - rad[m-2]; - h2 = rad[m] - rad[m-1]; - - f1 = rad_f[m-2]; - f2 = rad_f[m-1]; - f3 = rad_f[m]; - - h3 = -(h1+h2); - f4 = f1; - } - else - { - h1 = rad[m-1] - rad[m-2]; - h2 = rad[m] - rad[m-1]; - h3 = rad[m+1] - rad[m]; - - f1 = rad_f[m-2]; - f2 = rad_f[m-1]; - f3 = rad_f[m]; - f4 = rad_f[m+1]; - } - - //Calculate the value at R - - g1 = ((f3-f2)*h1/h2 + (f2-f1)*h2/h1)/(h1+h2); - g2 = ((f4-f3)*h2/h3 + (f3-f2)*h3/h2)/(h2+h3); - - x1 = R - rad[m-1]; - x2 = R - rad[m]; - y1 = x1/h2; - y2 = x2/h2; - - f = y2*y2*(3.0*f2 + h2*g1 + (2.0*f2 + h2*g1)*y2) - + y1*y1*(3.0*f3 - h2*g2 - (2.0*f3 - h2*g2)*y1); - } - - result = f; - - ModuleBase::timer::tick("Mathzone_Add1","RadialF"); - return result; -} - -// Interpolation for Numerical Orbitals +/// Interpolation for Numerical Orbitals double Mathzone_Add1::Uni_RadialF ( const double* old_phi, @@ -1368,4 +409,4 @@ void Mathzone_Add1::Uni_Deriv_Phi ModuleBase::timer::tick("Mathzone_Add1", "Uni_Deriv_Phi"); } -} \ No newline at end of file +} diff --git a/source/module_base/mathzone_add1.h b/source/module_base/mathzone_add1.h index 68d76560d45..05565ff3c2f 100644 --- a/source/module_base/mathzone_add1.h +++ b/source/module_base/mathzone_add1.h @@ -1,8 +1,8 @@ #ifndef MATHZONE_ADD1_H #define MATHZONE_ADD1_H -#include #include +#include namespace ModuleBase { @@ -10,188 +10,74 @@ namespace ModuleBase /************************************************************************ LiaoChen add @ 2010/03/09 to add efficient functions in LCAO calculation ************************************************************************/ -//Only used in module_orbital +// Only used in module_orbital class Mathzone_Add1 { -public: - + public: Mathzone_Add1(); ~Mathzone_Add1(); - static void Sph_Bes (double x, int lmax, double *sb, double *dsb); - - //calculate jlx and its derivatives - static void Spherical_Bessel - ( - const int &msh, //number of grid points - const double *r,//radial grid - const double &q, // - const int &l, //angular momentum - double *sj, //jl(1:msh) = j_l(q*r(i)),spherical bessel function - double *sjp - ); - - /*********************************************** - function : integrate function within [a,b] - formula : \int func(r) dr - ***********************************************/ - static double uni_simpson (const double* func, const int& mshr, const double& dr); - - /******************************************************** - function : radial fourier transform - formula : F(k)_{n,l} = \int r^{n} dr jl(kr) \phi(r) - input : - int n (exponential index) - int aml (angular momentum) - int mshk (number of kpoints) - double* arr_k (an array of kpoints) - int mshr (number of grids) - double* ri (radial points) - double dr (delta r) - double* phir (function to be transformed) - output : - double* phik (with size of mshk) - ********************************************************/ - static void uni_radfft - ( - const int& n, - const int& aml, - const int& mshk, - const double* arr_k, - const int& mshr, - const double* ri, - const double& dr, - const double* phir, - double* phik - ); - - static void Sbt_new - ( - const int& polint_order, - const int& l, - const double* k, - const double& dk, - const int& mshk, - const double* r, - const double& dr, - const int& mshr, - const double* fr, - const int& rpow, - double* fk - ); + static double dualfac(const int& l); + static double factorial(const int& l); + /** + * @brief calculate second derivatives for cubic + * spline interpolation + * + * @param[in] rad x before interpolation + * @param[in] rad_f f(x) before interpolation + * @param[in] mesh number of x before interpolation + * @param[in] yp1 f'(0) boundary condition + * @param[in] ypn f'(n) boundary condition + * @param[out] y2 f''(x) + */ + static void SplineD2(const double* rad, + const double* rad_f, + const int& mesh, + const double& yp1, + const double& ypn, + double* y2); - static double dualfac (const int& l); - static double factorial (const int& l); - static double fourier_cosine_transform - ( - const double* func, - const double* r, - const int& mshr, - const double& dr, - const double& k - ); - - static double fourier_sine_transform - ( - const double* func, - const double* r, - const int& mshr, - const double& dr, - const double& k - ); - -/* - static void RecursiveRadfft - ( - const int& l, - const double** sb, - const double** dsb, - const double* arr_k, - const double* kpoint, - const int& kmesh, - const double& dk, - double& rs, - double& drs - ); -*/ + /** + * @brief cubic spline interpolation + * + * @param[in] rad x before interpolation + * @param[in] rad_f f(x) before inpterpolation + * @param[in] y2 f''(x) before interpolation + * @param[in] mesh number of x before interpolation + * @param[in] r x after interpolation + * @param[in] rsize number of x after interpolation + * @param[out] y f(x) after interpolation + * @param[out] dy f'(x) after interpolation + */ + static void Cubic_Spline_Interpolation(const double* const rad, + const double* const rad_f, + const double* const y2, + const int& mesh, + const double* const r, + const int& rsize, + double* const y, + double* const dy); - static double Polynomial_Interpolation - ( - const double* xa, - const double* ya, - const int& n, - const double& x - ); - - static void SplineD2 - ( - const double *rad, - const double *rad_f, - const int& mesh, - const double &yp1, - const double &ypn, - double* y2 - ); + /** + * @brief "spline like interpolation" of a uniform + * funcation of r + * + * @param[in] rad_f f(x) before interpolation + * @param[in] msh number of x known + * @param[in] dr uniform distance of x + * @param R f(R) is to be calculated + * @return double f(R) + */ + static double Uni_RadialF(const double* rad_f, const int& msh, const double& dr, const double& R); - static void Cubic_Spline_Interpolation - ( - const double * const rad, - const double * const rad_f, - const double * const y2, - const int& mesh, - const double * const r, - const int& rsize, - double * const y, - double * const dy - ); - - static double RadialF - ( - const double* rad, - const double* rad_f, - const int& msh, - const int& l, - const double& R - ); - - static double Uni_RadialF - ( - const double* rad_f, - const int& msh, - const double& dr, - const double& R - ); - - static void test (); - static void test2 (); - - - static void Uni_Deriv_Phi - ( - const double *radf, - const int &mesh, - const double &dr, - const int &nd, - double* phind - ); - - private: - const static int sph_lmax = 20; - static double** c_ln_c; - static double** c_ln_s; - static bool flag_jlx_expand_coef; + static void Uni_Deriv_Phi(const double* radf, const int& mesh, const double& dr, const int& nd, double* phind); - static void expand_coef_jlx (); - static double pol_seg_int - ( - const int& polint_order, - const double* coef, - const int& n, - const double* k, - const int& ik - ); + private: + const static int sph_lmax = 20; + static double** c_ln_c; + static double** c_ln_s; }; -} +} // namespace ModuleBase #endif diff --git a/source/module_base/realarray.cpp b/source/module_base/realarray.cpp index e6a0e60ddbe..757bb9ee386 100644 --- a/source/module_base/realarray.cpp +++ b/source/module_base/realarray.cpp @@ -53,6 +53,22 @@ realArray::realArray(const int d1,const int d2,const int d3,const int d4) ++arrayCount; } +realArray::realArray(const realArray &cd) +{ + this->size = cd.getSize(); + this->ptr = new double[size]; + for (int i = 0; i < size; i++) + this->ptr[i] = cd.ptr[i]; + this->dim = cd.dim; + this->bound1 = cd.bound1; + this->bound2 = cd.bound2; + this->bound3 = cd.bound3; + this->bound4 = cd.bound4; + + ++arrayCount; +} + + //******************************** // // Destructor for class realArray @@ -65,8 +81,8 @@ realArray ::~realArray() void realArray::freemem() { - delete [] ptr; - ptr = NULL; + delete [] ptr; + ptr = NULL; } void realArray::create(const int d1,const int d2,const int d3,const int d4) @@ -173,4 +189,4 @@ void realArray::zero_out(void) return; } -} \ No newline at end of file +} diff --git a/source/module_base/realarray.h b/source/module_base/realarray.h index 45cc5fc84c4..1dc2d7deeff 100644 --- a/source/module_base/realarray.h +++ b/source/module_base/realarray.h @@ -5,82 +5,162 @@ #ifndef REALARRAY_H #define REALARRAY_H -#include +#include #include #include -#include +#include #ifdef _MCD_CHECK //#include "./src_parallel/mcd.h" #endif - namespace ModuleBase { - +/** + * @brief double float array + * + */ class realArray { -public: - double *ptr; - - realArray(const int d1 = 1 ,const int d2 = 1,const int d3 = 1); - realArray(const int d1, const int d2,const int d3,const int d4); - ~realArray(); - - void create(const int d1,const int d2,const int d3); - void create(const int d1,const int d2,const int d3,const int d4); - - const realArray &operator=(const realArray &right); - const realArray &operator=(const double &right); - - double &operator()(const int d1,const int d2,const int d3); - double &operator()(const int d1,const int d2,const int d3,const int d4); - - const double &operator()(const int d1,const int d2,const int d3)const; - const double &operator()(const int d1,const int d2,const int d3,const int d4)const; - - void zero_out(void); - - int getSize() const - { return size;} - - int getDim() const - { return dim;} - - int getBound1() const - { return bound1;} - - int getBound2() const - { return bound2;} - - int getBound3() const - { return bound3;} - - int getBound4() const - { return bound4;} - - static int getArrayCount(void) - { return arrayCount;} - -private: - int size; - int dim; - int bound1, bound2, bound3, bound4; - static int arrayCount; - - void freemem(); + public: + double *ptr; + + realArray(const int d1 = 1, const int d2 = 1, const int d3 = 1); + realArray(const int d1, const int d2, const int d3, const int d4); + ~realArray(); + + /** + * @brief create 3 dimensional real array + * + * @param[in] d1 The first dimension size + * @param[in] d2 The second dimension size + * @param[in] d3 The third dimension size + */ + void create(const int d1, const int d2, const int d3); + void create(const int d1, const int d2, const int d3, const int d4); + + realArray(const realArray &cd); + + /** + * @brief Equal a realArray to another one + * + * @param right + * @return const realArray& + */ + const realArray &operator=(const realArray &right); + /** + * @brief Set all value of an array to a double float number + * + * @param right + * @return const realArray& + */ + const realArray &operator=(const double &right); + + /** + * @brief Access elements by using operator "()" + * + * @param d1 + * @param d2 + * @param d3 + * @return double& + */ + double &operator()(const int d1, const int d2, const int d3); + double &operator()(const int d1, const int d2, const int d3, const int d4); + + /** + * @brief Access elements by using "()" through pointer + * without changing its elements + * + * @param d1 + * @param d2 + * @param d3 + * @return const double& + */ + const double &operator()(const int d1, const int d2, const int d3) const; + const double &operator()(const int d1, const int d2, const int d3, const int d4) const; + + /** + * @brief Set all elements of an IntArray to zero + * + */ + void zero_out(void); + + /** + * @brief Get the Size object + * + * @return int + */ + int getSize() const + { + return size; + } + + /** + * @brief Get the Dim object + * i.e. the dimension of a real array + * + * @return int + */ + int getDim() const + { + return dim; + } + + /** + * @brief Get the Bound1 object + * i.e. the first dimension size + * + * @return int + */ + int getBound1() const + { + return bound1; + } + + int getBound2() const + { + return bound2; + } + + int getBound3() const + { + return bound3; + } + + int getBound4() const + { + return bound4; + } + + /** + * @brief Get the Array Count object + * + * @return int + */ + static int getArrayCount(void) + { + return arrayCount; + } + + private: + int size; + int dim; + int bound1, bound2, bound3, bound4; + static int arrayCount; + + void freemem(); }; //************************************************** // set elements of a as zeros which a is 1_d array. //************************************************** -template -void zeros(T *u,const int n) +template void zeros(T *u, const int n) { - assert(n>0); - for (int i = 0;i < n;i++) u[i] = 0; + assert(n > 0); + for (int i = 0; i < n; i++) + u[i] = 0; } -} +} // namespace ModuleBase -#endif // realArray class +#endif // realArray class diff --git a/source/module_base/test/CMakeLists.txt b/source/module_base/test/CMakeLists.txt index 1241410e700..b746c915dfb 100644 --- a/source/module_base/test/CMakeLists.txt +++ b/source/module_base/test/CMakeLists.txt @@ -38,3 +38,51 @@ AddTest( LIBS ${math_libs} SOURCES complexmatrix_test.cpp ../complexmatrix.cpp ../matrix.cpp ) +AddTest( + TARGET base_matrix + LIBS ${math_libs} + SOURCES matrix_test.cpp ../matrix.cpp +) +AddTest( + TARGET base_sph_bessel_recursive + SOURCES sph_bessel_recursive_test.cpp ../sph_bessel_recursive-d1.cpp ../sph_bessel_recursive-d2.cpp +) +AddTest( + TARGET base_math_sphbes + SOURCES math_sphbes_test.cpp ../math_sphbes.cpp ../timer.cpp +) +AddTest( + TARGET base_realarray + SOURCES realarray_test.cpp ../realarray.cpp +) +AddTest( + TARGET base_intarray + SOURCES intarray_test.cpp ../intarray.cpp +) +AddTest( + TARGET base_vector3 + SOURCES vector3_test.cpp +) +AddTest( + TARGET base_mathzone + LIBS ${math_libs} + SOURCES mathzone_test.cpp ../mathzone.cpp ../matrix3.cpp ../matrix.cpp ../tool_quit.cpp ../global_variable.cpp ../global_file.cpp ../memory.cpp ../timer.cpp +) +AddTest( + TARGET base_math_polyint + SOURCES math_polyint_test.cpp ../math_polyint.cpp ../realarray.cpp ../timer.cpp +) +AddTest( + TARGET base_ylmreal + LIBS ${math_libs} + SOURCES math_ylmreal_test.cpp ../math_ylmreal.cpp ../ylm.cpp ../realarray.cpp ../timer.cpp ../matrix.cpp ../vector3.h +) +AddTest( + TARGET base_mathzone_add1 + LIBS ${math_libs} + SOURCES mathzone_add1_test.cpp ../mathzone_add1.cpp ../math_sphbes.cpp ../mathzone.cpp ../matrix3.cpp ../matrix.cpp ../tool_quit.cpp ../global_variable.cpp ../global_file.cpp ../memory.cpp ../timer.cpp +) +AddTest( + TARGET base_math_bspline + SOURCES math_bspline_test.cpp ../math_bspline.cpp +) diff --git a/source/module_base/test/intarray_test.cpp b/source/module_base/test/intarray_test.cpp new file mode 100644 index 00000000000..d6694abea38 --- /dev/null +++ b/source/module_base/test/intarray_test.cpp @@ -0,0 +1,308 @@ +#include "../intarray.h" +#include "gtest/gtest.h" + +/************************************************ + * unit test of class IntArray + ***********************************************/ + +/** + * - Tested Functions: + * - Construct + * - construct an int array (2 to 6 dimensions) + * - Creat + * - create an int array (2 to 6 dimensions) + * - GetArrayCount + * - get the total number of int array created + * - GetSize + * - get the total size of an int array + * - GetDim + * - get the dimension of an int array + * - ZeroOut + * - set all elements of an int array to zero + * - GetBound + * - get the size of each dimension of an int array + * - ArrayEqReal + * - set all value of an array to an int number + * - ArrayEqArray + * - equal an intarray to another intarray + * - Parentheses + * - access element by using operator"()" + * - ConstParentheses + * - access element by using "()" through pointer + * - without changing its elements + */ + +class IntArrayTest : public testing::Test +{ +protected: + ModuleBase::IntArray a2, a3, a4, a5, a6; + int aa = 11; + int bb = 1; + int count0; + int count1; + const int zero = 0; +}; + +TEST_F(IntArrayTest,GetArrayCount) +{ + count0 = ModuleBase::IntArray::getArrayCount(); + ModuleBase::IntArray c3, c4; + count1 = ModuleBase::IntArray::getArrayCount(); + EXPECT_EQ((count1-count0),2); +} + +TEST_F(IntArrayTest,Construct) +{ + ModuleBase::IntArray x2(1,5); + ModuleBase::IntArray x3(1,5,3); + ModuleBase::IntArray x4(1,7,3,4); + ModuleBase::IntArray x5(1,5,3,8,2); + ModuleBase::IntArray x6(1,7,3,4,3,2); + EXPECT_EQ(x2.getSize(),5); + EXPECT_EQ(x3.getSize(),15); + EXPECT_EQ(x4.getSize(),84); + EXPECT_EQ(x5.getSize(),240); + EXPECT_EQ(x6.getSize(),504); +} + +TEST_F(IntArrayTest,Create) +{ + a2.create(2,1); + a3.create(3,2,1); + a4.create(4,3,2,1); + a5.create(5,4,3,2,1); + a6.create(6,5,4,3,2,1); + EXPECT_EQ(a2.getSize(),2); + EXPECT_EQ(a3.getSize(),6); + EXPECT_EQ(a4.getSize(),24); + EXPECT_EQ(a5.getSize(),120); + EXPECT_EQ(a6.getSize(),720); +} + + +TEST_F(IntArrayTest,GetSize) +{ + ModuleBase::IntArray x3(1,5,3); + ModuleBase::IntArray x4(1,7,3,4); + EXPECT_EQ(x3.getSize(),15); + EXPECT_EQ(x4.getSize(),84); +} + +TEST_F(IntArrayTest,GetDim) +{ + a2.create(2,3); + a3.create(3,5,1); + a4.create(4,3,7,1); + a5.create(5,4,1,2,1); + a6.create(6,5,9,3,2,1); + EXPECT_EQ(a2.getDim(),2); + EXPECT_EQ(a3.getDim(),3); + EXPECT_EQ(a4.getDim(),4); + EXPECT_EQ(a5.getDim(),5); + EXPECT_EQ(a6.getDim(),6); +} + +TEST_F(IntArrayTest,ZeroOut) +{ + a2.create(2,3); + a3.create(3,5,1); + a4.create(4,3,7,1); + a5.create(5,4,1,2,1); + a6.create(6,5,9,3,2,1); + a2.zero_out(); + a3.zero_out(); + a4.zero_out(); + a5.zero_out(); + a6.zero_out(); + for (int i=0;i + +#define doublethreshold 1e-9 + +/************************************************ +* unit test of class PolyInt +***********************************************/ + +/** + * This unit test is to verify the accuracy of + * interpolation method on the function sin(x)/x + * with a interval of 0.01. + * sin(x)/x is one of the solution of spherical bessel + * function when l=0. + * + * - Tested function: + * - 4 types of Polynomial_Interpolation + * - Polynomial_Interpolation_xy + */ + + +class bessell0 : public testing::Test +{ + protected: + + int TableLength = 400; + double interval = 0.01; + ModuleBase::realArray table3,table4; + ModuleBase::realArray y3; + double *tablex; + double *tabley; + + double sinc(double x) {return sin(x)/x;} + + void SetUp() + { + tablex = new double[TableLength]; + tabley = new double[TableLength]; + table3.create(1,1,TableLength); + table4.create(1,1,1,TableLength); + y3.create(1,1,TableLength); + + for(int i=1;i + +#ifdef __MPI +#include"mpi.h" +#endif + +#include"gtest/gtest.h" + +#define doublethreshold 1e-7 + + +/************************************************ +* unit test of class Integral +***********************************************/ + +/** + * Note: this unit test try to ensure the invariance + * of the spherical Bessel produced by class Sphbes, + * and the reference results are produced by ABACUS + * at 2022-1-27. + * + * Tested function: + * - Spherical_Bessel. + * - Spherical_Bessel_Roots + * + */ + +double mean(const double* vect, const int totN) +{ + double meanv = 0.0; + for (int i=0; i< totN; ++i) {meanv += vect[i]/totN;} + return meanv; +} + +class Sphbes : public testing::Test +{ + protected: + + int msh = 700; + int l0 = 0; + int l1 = 1; + int l2 = 2; + int l3 = 3; + int l4 = 4; + int l5 = 5; + int l6 = 6; + int l7 = 7; + double q = 1.0; + double *r = new double[msh]; + double *jl = new double[msh]; + + void SetUp() + { + for(int i=0; i + +#define doublethreshold 1e-12 + +/************************************************ +* unit test of class YlmReal and Ylm +***********************************************/ + +/** + * For lmax <5 cases, the reference values are calculated by the formula from + * https://formulasearchengine.com/wiki/Table_of_spherical_harmonics. Note, these + * formula lack of the Condon–Shortley phase (-1)^m, and in this unit test, item + * (-1)^m is multiplied. + * For lmax >=5, the reference values are calculated by YlmReal::Ylm_Real. + * + * - Tested functions of class YlmReal + * - Ylm_Real + * - Ylm_Real2 + * - rlylm + * + * - Tested functions of class Ylm + * - get_ylm_real + * - sph_harm + * - rl_sph_harm + * - grad_rl_sph_harm + * - + */ + + + +//mock functions of WARNING_QUIT and WARNING +namespace ModuleBase +{ + void WARNING_QUIT(const std::string &file,const std::string &description) {return ;} + void WARNING(const std::string &file,const std::string &description) {return ;} +} + + +class YlmRealTest : public testing::Test +{ + protected: + + int lmax = 7; //maximum angular quantum number + int ng = 4; //test the 4 selected points on the sphere + int nylm = 64; //total Ylm number; + + ModuleBase::matrix ylm; //Ylm + ModuleBase::Vector3 *g; //vectors of the 4 points + double *ref; //reference of Ylm + double *rly; //Ylm + double (*rlgy)[3]; //the gradient of Ylm + std::vector rlyvector; //Ylm + std::vector> rlgyvector; //the gradient of Ylm + + //Ylm function + inline double norm(const double &x, const double &y, const double &z) {return sqrt(x*x + y*y + z*z);} + double y00(const double &x, const double &y, const double &z) {return 1.0/2.0/sqrt(M_PI);} + double y10(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return sqrt(3.0/(4.0*M_PI)) * z / r;} + double y11(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*sqrt(3.0/(4.*M_PI)) * x / r;} + double y1m1(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*sqrt(3./(4.*M_PI)) * y / r;} // y1m1 means Y1,-1 + double y20(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return 1./4. * sqrt(5./M_PI) * (-1.*x*x - y*y + 2.*z*z) / (r*r);} + double y21(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*1./2. * sqrt(15./M_PI) * (z*x) / (r*r);} + double y2m1(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*1./2. * sqrt(15./M_PI) * (z*y) / (r*r);} + double y22(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return 1./4. * sqrt(15./M_PI) * (x*x - y*y) / (r*r);} + double y2m2(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return 1./2. * sqrt(15./M_PI) * (x*y) / (r*r);} + double y30(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return 1./4. * sqrt(7./M_PI) * z*(2.*z*z-3.*x*x-3.*y*y) / (r*r*r);} + double y31(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*1./4. * sqrt(21./2./M_PI) * x*(4.*z*z-x*x-y*y) / (r*r*r);} + double y3m1(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*1./4. * sqrt(21./2./M_PI) * y*(4.*z*z-x*x-y*y) / (r*r*r);} + double y32(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return 1./4. * sqrt(105./M_PI) * (x*x - y*y)*z / (r*r*r);} + double y3m2(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return 1./2. * sqrt(105./M_PI) * x*y*z / (r*r*r);} + double y33(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*1./4. * sqrt(35./2./M_PI) * x*(x*x - 3.*y*y) / (r*r*r);} + double y3m3(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*1./4. * sqrt(35./2./M_PI) * y*(3.*x*x - y*y) / (r*r*r);} + double y40(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return 3./16.*sqrt(1./M_PI) * (35.*z*z*z*z - 30.*z*z*r*r + 3*r*r*r*r) / (r*r*r*r);} + double y41(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*3./4.*sqrt(5./2./M_PI) * x*z*(7.*z*z - 3*r*r) / (r*r*r*r);} + double y4m1(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*3./4.*sqrt(5./2./M_PI) * y*z*(7.*z*z - 3.*r*r) / (r*r*r*r);} + double y42(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return 3./8.*sqrt(5./M_PI) * (x*x-y*y)*(7.*z*z-r*r) / (r*r*r*r);} + double y4m2(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return 3./4.*sqrt(5./M_PI) * x*y*(7.*z*z - r*r) / (r*r*r*r);} + double y43(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*3./4.*sqrt(35./2./M_PI) * x*z*(x*x - 3.*y*y) / (r*r*r*r);} + double y4m3(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return -1.0*3./4.*sqrt(35./2./M_PI) * y*z*(3.*x*x - y*y) / (r*r*r*r);} + double y44(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return 3./16.*sqrt(35./M_PI) * (x*x*(x*x - 3.*y*y) - y*y*(3.*x*x-y*y)) / (r*r*r*r);} + double y4m4(const double &x, const double &y, const double &z) {double r=norm(x,y,z); return 3./4.*sqrt(35./M_PI) * x*y*(x*x - y*y) / (r*r*r*r);} + + //the reference values are calculated by ModuleBase::Ylm::grad_rl_sph_harm + //1st dimension: example, 2nd dimension: Ylm, 3rd dimension: dx/dy/dz + double rlgyref[4][64][3] = { + { { 0.00000e+00, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 4.88603e-01}, {-4.88603e-01, 0.00000e+00, 0.00000e+00}, + { 0.00000e+00, -4.88603e-01, 0.00000e+00}, {-6.30783e-01, 0.00000e+00, 0.00000e+00}, {-0.00000e+00, 0.00000e+00, -1.09255e+00}, + { 0.00000e+00, -0.00000e+00, 0.00000e+00}, { 1.09255e+00, 0.00000e+00, 0.00000e+00}, {-0.00000e+00, 1.09255e+00, -0.00000e+00}, + {-0.00000e+00, 0.00000e+00, -1.11953e+00}, { 1.37114e+00, 0.00000e+00, -0.00000e+00}, { 0.00000e+00, 4.57046e-01, 0.00000e+00}, + { 0.00000e+00, 0.00000e+00, 1.44531e+00}, {-0.00000e+00, 0.00000e+00, 0.00000e+00}, {-1.77013e+00, 0.00000e+00, -0.00000e+00}, + { 0.00000e+00, -1.77013e+00, 0.00000e+00}, { 1.26943e+00, 0.00000e+00, -0.00000e+00}, { 0.00000e+00, 0.00000e+00, 2.00714e+00}, + { 0.00000e+00, 0.00000e+00, 0.00000e+00}, {-1.89235e+00, 0.00000e+00, 0.00000e+00}, {-0.00000e+00, -9.46175e-01, 0.00000e+00}, + {-0.00000e+00, 0.00000e+00, -1.77013e+00}, { 0.00000e+00, -0.00000e+00, 0.00000e+00}, { 2.50334e+00, 0.00000e+00, 0.00000e+00}, + {-0.00000e+00, 2.50334e+00, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 1.75425e+00}, {-2.26473e+00, 0.00000e+00, 0.00000e+00}, + { 0.00000e+00, -4.52947e-01, 0.00000e+00}, {-0.00000e+00, 0.00000e+00, -2.39677e+00}, {-0.00000e+00, -0.00000e+00, 0.00000e+00}, + { 2.44619e+00, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, 1.46771e+00, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 2.07566e+00}, + {-0.00000e+00, 0.00000e+00, 0.00000e+00}, {-3.28191e+00, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, -3.28191e+00, 0.00000e+00}, + {-1.90708e+00, 0.00000e+00, 0.00000e+00}, {-0.00000e+00, 0.00000e+00, -2.91311e+00}, { 0.00000e+00, -0.00000e+00, 0.00000e+00}, + { 2.76362e+00, 0.00000e+00, -0.00000e+00}, {-0.00000e+00, 9.21205e-01, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 2.76362e+00}, + { 0.00000e+00, 0.00000e+00, 0.00000e+00}, {-3.02739e+00, 0.00000e+00, 0.00000e+00}, {-0.00000e+00, -2.01826e+00, 0.00000e+00}, + {-0.00000e+00, 0.00000e+00, -2.36662e+00}, { 0.00000e+00, -0.00000e+00, 0.00000e+00}, { 4.09910e+00, 0.00000e+00, 0.00000e+00}, + {-0.00000e+00, 4.09910e+00, 0.00000e+00}, {-0.00000e+00, 0.00000e+00, -2.38995e+00}, { 3.16161e+00, 0.00000e+00, -0.00000e+00}, + { 0.00000e+00, 4.51658e-01, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 3.31900e+00}, {-0.00000e+00, 0.00000e+00, 0.00000e+00}, + {-3.28564e+00, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, -1.40813e+00, 0.00000e+00}, {-0.00000e+00, 0.00000e+00, -3.11349e+00}, + {-0.00000e+00, -0.00000e+00, 0.00000e+00}, { 3.63241e+00, 0.00000e+00, -0.00000e+00}, { 0.00000e+00, 2.59458e+00, 0.00000e+00}, + { 0.00000e+00, 0.00000e+00, 2.64596e+00}, { 0.00000e+00, 0.00000e+00, 0.00000e+00}, {-4.95014e+00, 0.00000e+00, -0.00000e+00}, + { 0.00000e+00, -4.95014e+00, 0.00000e+00} + }, + { + { 0.00000e+00, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 4.88603e-01}, {-4.88603e-01, 0.00000e+00, 0.00000e+00}, + { 0.00000e+00, -4.88603e-01, 0.00000e+00}, { 0.00000e+00, -6.30783e-01, 0.00000e+00}, {-0.00000e+00, 0.00000e+00, 0.00000e+00}, + { 0.00000e+00, -0.00000e+00, -1.09255e+00}, { 0.00000e+00, -1.09255e+00, 0.00000e+00}, { 1.09255e+00, 0.00000e+00, -0.00000e+00}, + { 0.00000e+00, -0.00000e+00, -1.11953e+00}, { 4.57046e-01, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, 1.37114e+00, -0.00000e+00}, + { 0.00000e+00, -0.00000e+00, -1.44531e+00}, { 0.00000e+00, 0.00000e+00, 0.00000e+00}, { 1.77013e+00, 0.00000e+00, 0.00000e+00}, + { 0.00000e+00, 1.77013e+00, 0.00000e+00}, { 0.00000e+00, 1.26943e+00, -0.00000e+00}, { 0.00000e+00, 0.00000e+00, 0.00000e+00}, + { 0.00000e+00, 0.00000e+00, 2.00714e+00}, { 0.00000e+00, 1.89235e+00, -0.00000e+00}, {-9.46175e-01, 0.00000e+00, 0.00000e+00}, + { 0.00000e+00, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 1.77013e+00}, { 0.00000e+00, 2.50334e+00, -0.00000e+00}, + {-2.50334e+00, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 1.75425e+00}, {-4.52947e-01, 0.00000e+00, 0.00000e+00}, + { 0.00000e+00, -2.26473e+00, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 2.39677e+00}, {-0.00000e+00, 0.00000e+00, 0.00000e+00}, + {-1.46771e+00, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, -2.44619e+00, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 2.07566e+00}, + {-0.00000e+00, 0.00000e+00, 0.00000e+00}, {-3.28191e+00, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, -3.28191e+00, 0.00000e+00}, + { 0.00000e+00, -1.90708e+00, 0.00000e+00}, {-0.00000e+00, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, -0.00000e+00, -2.91311e+00}, + { 0.00000e+00, -2.76362e+00, 0.00000e+00}, { 9.21205e-01, 0.00000e+00, 0.00000e+00}, {-0.00000e+00, 0.00000e+00, 0.00000e+00}, + { 0.00000e+00, -0.00000e+00, -2.76362e+00}, { 0.00000e+00, -3.02739e+00, 0.00000e+00}, { 2.01826e+00, 0.00000e+00, 0.00000e+00}, + {-0.00000e+00, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, -0.00000e+00, -2.36662e+00}, { 0.00000e+00, -4.09910e+00, 0.00000e+00}, + { 4.09910e+00, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, -0.00000e+00, -2.38995e+00}, { 4.51658e-01, 0.00000e+00, 0.00000e+00}, + { 0.00000e+00, 3.16161e+00, -0.00000e+00}, { 0.00000e+00, -0.00000e+00, -3.31900e+00}, { 0.00000e+00, 0.00000e+00, 0.00000e+00}, + { 1.40813e+00, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, 3.28564e+00, -0.00000e+00}, { 0.00000e+00, -0.00000e+00, -3.11349e+00}, + { 0.00000e+00, 0.00000e+00, 0.00000e+00}, { 2.59458e+00, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, 3.63241e+00, -0.00000e+00}, + { 0.00000e+00, 0.00000e+00, -2.64596e+00}, { 0.00000e+00, 0.00000e+00, 0.00000e+00}, { 4.95014e+00, 0.00000e+00, 0.00000e+00}, + { 0.00000e+00, 4.95014e+00, -0.00000e+00} + }, + { + { 0.00000e+00, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 4.88603e-01}, {-4.88603e-01, 0.00000e+00, 0.00000e+00}, + { 0.00000e+00, -4.88603e-01, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 1.26157e+00}, {-1.09255e+00, 0.00000e+00, 0.00000e+00}, + { 0.00000e+00, -1.09255e+00, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 2.22045e-16}, {-0.00000e+00, 0.00000e+00, -0.00000e+00}, + { 0.00000e+00, 0.00000e+00, 2.23906e+00}, {-1.82818e+00, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, -1.82818e+00, 0.00000e+00}, + { 0.00000e+00, 0.00000e+00, 8.81212e-16}, {-0.00000e+00, 0.00000e+00, 0.00000e+00}, {-1.84324e-16, 0.00000e+00, 0.00000e+00}, + { 0.00000e+00, 5.55112e-17, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 3.38514e+00}, {-2.67619e+00, 0.00000e+00, 0.00000e+00}, + { 0.00000e+00, -2.67619e+00, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 2.30756e-15}, {-0.00000e+00, 0.00000e+00, 0.00000e+00}, + {-5.52973e-16, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, 1.66533e-16, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 0.00000e+00}, + {-0.00000e+00, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 4.67801e+00}, {-3.62357e+00, 0.00000e+00, 0.00000e+00}, + { 0.00000e+00, -3.62357e+00, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 4.87108e-15}, {-0.00000e+00, 0.00000e+00, 0.00000e+00}, + {-1.22267e-15, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, 3.68219e-16, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 0.00000e+00}, + {-0.00000e+00, 0.00000e+00, 0.00000e+00}, { 4.93038e-32, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, -6.16298e-33, 0.00000e+00}, + { 0.00000e+00, 0.00000e+00, 6.10264e+00}, {-4.66097e+00, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, -4.66097e+00, 0.00000e+00}, + { 0.00000e+00, 0.00000e+00, 8.98664e-15}, {-0.00000e+00, 0.00000e+00, 0.00000e+00}, {-2.30221e-15, 0.00000e+00, 0.00000e+00}, + { 0.00000e+00, 6.93334e-16, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 0.00000e+00}, {-0.00000e+00, 0.00000e+00, 0.00000e+00}, + { 1.77767e-31, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, -2.22209e-32, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 0.00000e+00}, + {-0.00000e+00, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 7.64784e+00}, {-5.78122e+00, 0.00000e+00, 0.00000e+00}, + { 0.00000e+00, -5.78122e+00, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 1.51096e-14}, {-0.00000e+00, 0.00000e+00, 0.00000e+00}, + {-3.91011e-15, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, 1.17757e-15, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 0.00000e+00}, + {-0.00000e+00, 0.00000e+00, 0.00000e+00}, { 4.67737e-31, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, -5.84671e-32, 0.00000e+00}, + { 0.00000e+00, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 0.00000e+00}, { 1.13319e-47, 0.00000e+00, 0.00000e+00}, + { 0.00000e+00, -1.41649e-48, 0.00000e+00} + }, + { + { 0.00000e+00, 0.00000e+00, 0.00000e+00}, { 0.00000e+00, 0.00000e+00, 4.88603e-01}, {-4.88603e-01, 0.00000e+00, 0.00000e+00}, + { 0.00000e+00, -4.88603e-01, 0.00000e+00}, { 3.64183e-01, 3.64183e-01, -7.28366e-01}, { 6.30783e-01, -0.00000e+00, 6.30783e-01}, + {-0.00000e+00, 6.30783e-01, 6.30783e-01}, {-6.30783e-01, 6.30783e-01, -1.66533e-16}, {-6.30783e-01, -6.30783e-01, 0.00000e+00}, + {-7.46353e-01, -7.46353e-01, 0.00000e+00}, { 0.00000e+00, 3.04697e-01, -1.21879e+00}, { 3.04697e-01, 0.00000e+00, -1.21879e+00}, + { 9.63537e-01, -9.63537e-01, 4.01253e-16}, { 9.63537e-01, 9.63537e-01, 9.63537e-01}, {-4.44089e-16, 1.18009e+00, -2.22045e-16}, + {-1.18009e+00, -1.11022e-16, 0.00000e+00}, { 4.88603e-01, 4.88603e-01, 1.30294e+00}, {-1.03006e+00, -7.72548e-01, 7.72548e-01}, + {-7.72548e-01, -1.03006e+00, 7.72548e-01}, {-7.28366e-01, 7.28366e-01, -5.25363e-16}, {-3.64183e-01, -3.64183e-01, -2.18510e+00}, + { 7.69185e-16, -2.04397e+00, -6.81324e-01}, { 2.04397e+00, 1.92296e-16, 6.81324e-01}, { 9.63537e-01, 9.63537e-01, -1.44756e-16}, + {-9.63537e-01, 9.63537e-01, -5.55112e-17}, { 5.19779e-01, 5.19779e-01, -1.81923e+00}, { 1.40917e+00, 8.05238e-01, 8.05238e-01}, + { 8.05238e-01, 1.40917e+00, 8.05238e-01}, { 0.00000e+00, -4.44089e-16, 3.24739e-16}, {-1.06523e+00, -1.06523e+00, 2.13046e+00}, + {-2.17439e-01, 1.73951e+00, 1.73951e+00}, {-1.73951e+00, 2.17439e-01, -1.73951e+00}, {-1.84503e+00, -1.84503e+00, -9.22517e-01}, + { 1.84503e+00, -1.84503e+00, 6.58625e-16}, { 1.45863e+00, 1.11022e-15, 0.00000e+00}, {-8.88178e-16, 1.45863e+00, 0.00000e+00}, + {-1.46807e+00, -1.46807e+00, 5.87227e-01}, {-4.48502e-01, -3.36617e-16, -2.24251e+00}, {-3.36617e-16, -4.48502e-01, -2.24251e+00}, + { 7.09144e-01, -7.09144e-01, 1.87222e-16}, { 2.12743e+00, 2.12743e+00, -9.38779e-16}, { 7.09144e-01, -5.11006e-16, -2.12743e+00}, + { 1.02201e-15, -7.09144e-01, 2.12743e+00}, { 1.81260e+00, 1.81260e+00, 2.58943e+00}, {-2.07154e+00, 2.07154e+00, -1.66969e-15}, + {-3.03637e+00, -2.31111e-15, -6.07275e-01}, { 1.84889e-15, -3.03637e+00, -6.07275e-01}, { 1.05183e+00, -1.05183e+00, 5.77778e-17}, + { 1.05183e+00, 1.05183e+00, 4.03986e-17}, { 1.27464e+00, 1.27464e+00, 1.69952e+00}, {-1.28472e+00, -1.20442e+00, 1.92707e+00}, + {-1.20442e+00, -1.28472e+00, 1.92707e+00}, {-8.52285e-01, 8.52285e-01, -6.74704e-16}, {-1.50789e+00, -1.50789e+00, -2.95022e+00}, + {-1.11260e+00, -2.08612e+00, 9.27164e-01}, { 2.08612e+00, 1.11260e+00, -9.27164e-01}, {-3.07506e-01, -3.07506e-01, -3.69007e+00}, + { 1.23002e+00, -1.23002e+00, 2.28018e-15}, { 3.69007e+00, -1.53753e-01, 1.84503e+00}, {-1.53753e-01, 3.69007e+00, 1.84503e+00}, + {-2.35197e+00, 2.35197e+00, -8.00513e-16}, {-2.35197e+00, -2.35197e+00, -7.83988e-01}, { 1.37903e-15, -1.46671e+00, 9.77875e-17}, + { 1.46671e+00, 1.14919e-15, 1.34475e-16} + } + }; + + void SetUp() + { + ylm.create(nylm,ng); + g = new ModuleBase::Vector3[ng]; + g[0].set(1.0,0.0,0.0); + g[1].set(0.0,1.0,0.0); + g[2].set(0.0,0.0,1.0); + g[3].set(-1.0,-1.0,-1.0); + + rly = new double[nylm]; + rlyvector.resize(nylm); + rlgy = new double[nylm][3]; + rlgyvector.resize(nylm,std::vector(3)); + ref = new double[64*4]{ + y00(g[0].x, g[0].y, g[0].z), y00(g[1].x, g[1].y, g[1].z), y00(g[2].x, g[2].y, g[2].z), y00(g[3].x, g[3].y, g[3].z), + y10(g[0].x, g[0].y, g[0].z), y10(g[1].x, g[1].y, g[1].z), y10(g[2].x, g[2].y, g[2].z), y10(g[3].x, g[3].y, g[3].z), + y11(g[0].x, g[0].y, g[0].z), y11(g[1].x, g[1].y, g[1].z), y11(g[2].x, g[2].y, g[2].z), y11(g[3].x, g[3].y, g[3].z), + y1m1(g[0].x, g[0].y, g[0].z), y1m1(g[1].x, g[1].y, g[1].z), y1m1(g[2].x, g[2].y, g[2].z), y1m1(g[3].x, g[3].y, g[3].z), + y20(g[0].x, g[0].y, g[0].z), y20(g[1].x, g[1].y, g[1].z), y20(g[2].x, g[2].y, g[2].z), y20(g[3].x, g[3].y, g[3].z), + y21(g[0].x, g[0].y, g[0].z), y21(g[1].x, g[1].y, g[1].z), y21(g[2].x, g[2].y, g[2].z), y21(g[3].x, g[3].y, g[3].z), + y2m1(g[0].x, g[0].y, g[0].z), y2m1(g[1].x, g[1].y, g[1].z), y2m1(g[2].x, g[2].y, g[2].z), y2m1(g[3].x, g[3].y, g[3].z), + y22(g[0].x, g[0].y, g[0].z), y22(g[1].x, g[1].y, g[1].z), y22(g[2].x, g[2].y, g[2].z), y22(g[3].x, g[3].y, g[3].z), + y2m2(g[0].x, g[0].y, g[0].z), y2m2(g[1].x, g[1].y, g[1].z), y2m2(g[2].x, g[2].y, g[2].z), y2m2(g[3].x, g[3].y, g[3].z), + y30(g[0].x, g[0].y, g[0].z), y30(g[1].x, g[1].y, g[1].z), y30(g[2].x, g[2].y, g[2].z), y30(g[3].x, g[3].y, g[3].z), + y31(g[0].x, g[0].y, g[0].z), y31(g[1].x, g[1].y, g[1].z), y31(g[2].x, g[2].y, g[2].z), y31(g[3].x, g[3].y, g[3].z), + y3m1(g[0].x, g[0].y, g[0].z), y3m1(g[1].x, g[1].y, g[1].z), y3m1(g[2].x, g[2].y, g[2].z), y3m1(g[3].x, g[3].y, g[3].z), + y32(g[0].x, g[0].y, g[0].z), y32(g[1].x, g[1].y, g[1].z), y32(g[2].x, g[2].y, g[2].z), y32(g[3].x, g[3].y, g[3].z), + y3m2(g[0].x, g[0].y, g[0].z), y3m2(g[1].x, g[1].y, g[1].z), y3m2(g[2].x, g[2].y, g[2].z), y3m2(g[3].x, g[3].y, g[3].z), + y33(g[0].x, g[0].y, g[0].z), y33(g[1].x, g[1].y, g[1].z), y33(g[2].x, g[2].y, g[2].z), y33(g[3].x, g[3].y, g[3].z), + y3m3(g[0].x, g[0].y, g[0].z), y3m3(g[1].x, g[1].y, g[1].z), y3m3(g[2].x, g[2].y, g[2].z), y3m3(g[3].x, g[3].y, g[3].z), + y40(g[0].x, g[0].y, g[0].z), y40(g[1].x, g[1].y, g[1].z), y40(g[2].x, g[2].y, g[2].z), y40(g[3].x, g[3].y, g[3].z), + y41(g[0].x, g[0].y, g[0].z), y41(g[1].x, g[1].y, g[1].z), y41(g[2].x, g[2].y, g[2].z), y41(g[3].x, g[3].y, g[3].z), + y4m1(g[0].x, g[0].y, g[0].z), y4m1(g[1].x, g[1].y, g[1].z), y4m1(g[2].x, g[2].y, g[2].z), y4m1(g[3].x, g[3].y, g[3].z), + y42(g[0].x, g[0].y, g[0].z), y42(g[1].x, g[1].y, g[1].z), y42(g[2].x, g[2].y, g[2].z), y42(g[3].x, g[3].y, g[3].z), + y4m2(g[0].x, g[0].y, g[0].z), y4m2(g[1].x, g[1].y, g[1].z), y4m2(g[2].x, g[2].y, g[2].z), y4m2(g[3].x, g[3].y, g[3].z), + y43(g[0].x, g[0].y, g[0].z), y43(g[1].x, g[1].y, g[1].z), y43(g[2].x, g[2].y, g[2].z), y43(g[3].x, g[3].y, g[3].z), + y4m3(g[0].x, g[0].y, g[0].z), y4m3(g[1].x, g[1].y, g[1].z), y4m3(g[2].x, g[2].y, g[2].z), y4m3(g[3].x, g[3].y, g[3].z), + y44(g[0].x, g[0].y, g[0].z), y44(g[1].x, g[1].y, g[1].z), y44(g[2].x, g[2].y, g[2].z), y44(g[3].x, g[3].y, g[3].z), + y4m4(g[0].x, g[0].y, g[0].z), y4m4(g[1].x, g[1].y, g[1].z), y4m4(g[2].x, g[2].y, g[2].z), y4m4(g[3].x, g[3].y, g[3].z), + 0.000000000000000, 0.000000000000000, 0.935602579627389, 0.090028400200397, + -0.452946651195697, -0.000000000000000, -0.000000000000000, -0.348678494661834, + -0.000000000000000, -0.452946651195697, -0.000000000000000, -0.348678494661834, + -0.000000000000000, 0.000000000000000, 0.000000000000000, -0.000000000000000, + -0.000000000000000, -0.000000000000000, 0.000000000000000, -0.000000000000000, + 0.489238299435250, 0.000000000000000, -0.000000000000000, -0.376615818502422, + 0.000000000000000, -0.489238299435250, -0.000000000000000, 0.376615818502422, + 0.000000000000000, 0.000000000000000, 0.000000000000000, 0.532615198330370, + 0.000000000000000, 0.000000000000000, 0.000000000000000, -0.000000000000000, + -0.656382056840170, -0.000000000000000, -0.000000000000000, -0.168427714314628, + -0.000000000000000, -0.656382056840170, -0.000000000000000, -0.168427714314628, + -0.317846011338142, -0.317846011338142, 1.017107236282055, 0.226023830284901, + -0.000000000000000, -0.000000000000000, -0.000000000000000, 0.258942827786103, + -0.000000000000000, -0.000000000000000, -0.000000000000000, 0.258942827786103, + 0.460602629757462, -0.460602629757462, 0.000000000000000, -0.000000000000000, + 0.000000000000000, 0.000000000000000, 0.000000000000000, -0.409424559784410, + -0.000000000000000, -0.000000000000000, -0.000000000000000, 0.136474853261470, + -0.000000000000000, 0.000000000000000, -0.000000000000000, -0.136474853261470, + -0.504564900728724, -0.504564900728724, 0.000000000000000, -0.598002845308118, + -0.000000000000000, -0.000000000000000, 0.000000000000000, 0.000000000000000, + -0.000000000000000, -0.000000000000000, -0.000000000000000, 0.350610246256556, + -0.000000000000000, -0.000000000000000, -0.000000000000000, 0.350610246256556, + 0.683184105191914, -0.683184105191914, 0.000000000000000, -0.000000000000000, + 0.000000000000000, 0.000000000000000, 0.000000000000000, -0.202424920056864, + 0.000000000000000, 0.000000000000000, 1.092548430592079, -0.350435072502801, + 0.451658037912587, 0.000000000000000, -0.000000000000000, 0.046358202625865, + 0.000000000000000, 0.451658037912587, -0.000000000000000, 0.046358202625865, + 0.000000000000000, -0.000000000000000, 0.000000000000000, 0.000000000000000, + 0.000000000000000, 0.000000000000000, 0.000000000000000, 0.492067081245654, + -0.469376801586882, -0.000000000000000, -0.000000000000000, 0.187354445356332, + -0.000000000000000, 0.469376801586882, -0.000000000000000, -0.187354445356332, + 0.000000000000000, 0.000000000000000, 0.000000000000000, 0.355076798886913, + 0.000000000000000, 0.000000000000000, 0.000000000000000, -0.000000000000000, + 0.518915578720260, 0.000000000000000, -0.000000000000000, -0.443845998608641, + 0.000000000000000, 0.518915578720260, -0.000000000000000, -0.443845998608641, + 0.000000000000000, -0.000000000000000, 0.000000000000000, 0.000000000000000, + 0.000000000000000, 0.000000000000000, 0.000000000000000, 0.452635881587108, + -0.707162732524596, 0.000000000000000, -0.000000000000000, 0.120972027847095, + -0.000000000000000, 0.707162732524596, -0.000000000000000, -0.120972027847095 + } ; + } + + void TearDown() + { + delete [] g; + delete [] ref; + delete [] rly; + delete [] rlgy; + } +}; + +TEST_F(YlmRealTest,YlmReal) +{ + ModuleBase::YlmReal::Ylm_Real(nylm,ng,g,ylm); + for(int i=0;i direct, cartesian; +}; + +TEST_F(MathzoneTest, PointwiseProduct) +{ + std::vector aa, bb, cc; + for(int i=0;i<10;i++) + { + aa.push_back(i*i); + bb.push_back(i*2); + } + cc = ModuleBase::Mathzone::Pointwise_Product(aa,bb); + for(int i=0;i<10;i++) + { + EXPECT_EQ(cc[i],i*i*i*2); + } +} + +TEST_F(MathzoneTest, Direct2Cartesian) +{ + direct.set(0.1,0.2,0.4); + cartesian.set(0.368,2.02,10.68); + ModuleBase::Vector3 cartnew; + ModuleBase::Mathzone::Direct_to_Cartesian(direct.x, + direct.y, + direct.z, + R11, R12, R13, + R21, R22, R23, + R31, R32, R33, + cartnew.x, + cartnew.y, + cartnew.z); + EXPECT_NEAR(cartnew.x,cartesian.x, 1e-15); + EXPECT_NEAR(cartnew.y,cartesian.y, 1e-15); + EXPECT_NEAR(cartnew.z,cartesian.z, 1e-15); +} + +TEST_F(MathzoneTest, Cartesian2Direct) +{ + direct.set(0.1,0.2,0.4); + cartesian.set(0.368,2.02,10.68); + ModuleBase::Vector3 directnew; + ModuleBase::Mathzone::Cartesian_to_Direct(cartesian.x, + cartesian.y, + cartesian.z, + R11, R12, R13, + R21, R22, R23, + R31, R32, R33, + directnew.x, + directnew.y, + directnew.z); + EXPECT_NEAR(directnew.x,direct.x, 1e-15); + EXPECT_NEAR(directnew.y,direct.y, 1e-15); + EXPECT_NEAR(directnew.z,direct.z, 1e-15); +} diff --git a/source/module_base/test/matrix_test.cpp b/source/module_base/test/matrix_test.cpp new file mode 100644 index 00000000000..b1365da4565 --- /dev/null +++ b/source/module_base/test/matrix_test.cpp @@ -0,0 +1,330 @@ +#include"../matrix.h" +#include"gtest/gtest.h" + +/************************************************ +* unit test of class matrix and related functions +***********************************************/ + +/** + * - Tested functions of class matrix: + * - constructor: + * - constructed by nrow and ncloumn + * - constructed by a matrix + * - constructed by the rvalue of a matrix + * - function create + * - operator "=": assigned by a matrix or the rvalue of a matrix + * - operator "()": access the element + * - operator "*=", "+=", "-=" + * - function trace_on + * - function zero_out + * - function max/min/absmax + * - function norm + * - function print (not called in abacus, no need to test) + * + * - Tested functions related to class matrix + * - operator "+", "-", "*" between two matrixs + * - operator "*" between a double and a matrix, and reverse. + * - function transpose + * - function trace_on + * - function mdot + * + */ + +//a mock function of WARNING_QUIT, to avoid the uncorrected call by matrix.cpp at line 37. +namespace ModuleBase +{ + void WARNING_QUIT(const std::string &file,const std::string &description) {return ;} +} + +class matrixTest : public testing::Test +{ + protected: + ModuleBase::matrix m23a,m33a,m33b,m33c,m34a,m34b; + + void SetUp() + { + m23a.create(2,3); + for (int i=1;i<=6;++i) {m23a.c[i-1] = i*1.0;} + + m33a.create(3,3); + for (int i=1;i<=9;++i) {m33a.c[i-1] = i*1.0;} + + m33b.create(3,3); + for (int i=1;i<=9;++i) {m33b.c[i-1] = i*11.1;} + + m33c.create(3,3,true); + m34a.create(3,4,true); + m34b.create(3,4,true); + } + +}; + +TEST(matrix,ConstructorNrNc) +{ + ModuleBase::matrix m(3,4,true); + EXPECT_EQ(m.nr,3); + EXPECT_EQ(m.nc,4); + EXPECT_DOUBLE_EQ(m(0,0),0.0); +} + +TEST_F(matrixTest,ConstructorMatrix) +{ + ModuleBase::matrix m(m33a); + int mnr = m.nr; + EXPECT_EQ(mnr,m33a.nr); + EXPECT_EQ(m.nc,m33a.nc); + for (int i=0;i<9;++i) + { + EXPECT_DOUBLE_EQ(m.c[i],m33a.c[i]); + } +} + +TEST_F(matrixTest,ConstructorMtrixRValue) +{ + + ModuleBase::matrix m(3.0*m33a); + EXPECT_EQ(m.nr,m33a.nr); + EXPECT_EQ(m.nc,m33a.nc); + for (int i=0;i<9;++i) + { + EXPECT_DOUBLE_EQ(m.c[i],m33a.c[i] * 3.0); + } +} + +TEST_F(matrixTest,Create) +{ + m33a.create(13,14,true); + EXPECT_EQ(m33a.nr,13); + EXPECT_EQ(m33a.nc,14); + for(int i=0;i<13*14;++i) + { + EXPECT_DOUBLE_EQ(m33a.c[i],0.0); + } +} + +TEST_F(matrixTest,OperatorEqualMatrix) +{ + ModuleBase::matrix m; + m = m33a; + EXPECT_EQ(m.nr,m33a.nr); + EXPECT_EQ(m.nc,m33a.nc); + for (int i=0;i<9;++i) + { + EXPECT_DOUBLE_EQ(m.c[i],m33a.c[i]); + } + + m23a = m33a; + EXPECT_EQ(m23a.nr,m33a.nr); + EXPECT_EQ(m23a.nc,m33a.nc); +} + +TEST_F(matrixTest,OperatorEqualMatrixRvalue) +{ + ModuleBase::matrix m; + m = 3.0 * m33a; + EXPECT_EQ(m.nr,m33a.nr); + EXPECT_EQ(m.nc,m33a.nc); + for (int i=0;i<9;++i) + { + EXPECT_DOUBLE_EQ(m.c[i],m33a.c[i] * 3.0); + } +} + +TEST_F(matrixTest,OperatorParentheses) +{ + //EXPECT_DEATH(m33a(3,3),""); + //EXPECT_DEATH(m33a(-1,0),""); + m33a(0,0) = 1.1; + EXPECT_DOUBLE_EQ(m33a(0,0),1.1); +} + +TEST_F(matrixTest,OperatorMultiplyEqual) +{ + m33b = m33a; + m33a *= 11.1; + for (int i=0;ia3) < &vect) +{ + double meanv = 0.0; + + int totN = vect.size(); + for (int i=0; i< totN; ++i) {meanv += vect[i]/totN;} + + return meanv; +} + +TEST(SphBessel,D1) +{ + int lmax = 7; + int rmesh = 700; + double dx = 0.01; + + ModuleBase::Sph_Bessel_Recursive::D1 sphbesseld1; + sphbesseld1.set_dx(dx); + sphbesseld1.cal_jlx(lmax,rmesh); + std::vector> jlx = sphbesseld1.get_jlx(); + + ASSERT_EQ(jlx.size(),static_cast(lmax + 1)); + EXPECT_NEAR( mean(jlx[0])/0.2084468748396, 1.0, threshold); + EXPECT_NEAR( mean(jlx[1])/0.12951635180384, 1.0, threshold); + EXPECT_NEAR( mean(jlx[2])/0.124201140093879, 1.0, threshold); + EXPECT_NEAR( mean(jlx[3])/0.118268654505568, 1.0, threshold); + EXPECT_NEAR( mean(jlx[4])/0.0933871035384385, 1.0, threshold); + EXPECT_NEAR( mean(jlx[5])/0.0603800487910689, 1.0, threshold); + EXPECT_NEAR( mean(jlx[6])/0.0327117051555907, 1.0, threshold); + EXPECT_NEAR( mean(jlx[7])/0.0152155566653926, 1.0, threshold); +} + + +TEST(SphBessel,D2) +{ + int lmax = 7; + int rmesh = 700; + int kmesh = 800; + double dx = 0.0001; + + ModuleBase::Sph_Bessel_Recursive::D2 sphbesseld2; + sphbesseld2.set_dx(dx); + sphbesseld2.cal_jlx(lmax,rmesh,kmesh); + std::vector>> jlxd2 = sphbesseld2.get_jlx(); + std::vector> jlx(lmax+1); + + ASSERT_EQ(jlxd2.size(),static_cast(lmax + 1)); + + //calculate the mean of jlxd2[i][j] and assign to jlx[i][j] + for(int i=0; i u (da,db,dc); + ModuleBase::Vector3 up (u); + EXPECT_EQ(u.x,3.0); + EXPECT_EQ(u.y,4.0); + EXPECT_EQ(u.z,5.0); + EXPECT_EQ(up.x,3.0); + EXPECT_EQ(up.y,4.0); + EXPECT_EQ(up.z,5.0); + // float Vector3 + ModuleBase::Vector3 v (fa,fb,fc); + ModuleBase::Vector3 vp (v); + EXPECT_EQ(v.x,3.0); + EXPECT_EQ(v.y,4.0); + EXPECT_EQ(v.z,5.0); + EXPECT_EQ(vp.x,3.0); + EXPECT_EQ(vp.y,4.0); + EXPECT_EQ(vp.z,5.0); + // int Vector3 + ModuleBase::Vector3 w (ia,ib,ic); + ModuleBase::Vector3 wp (w); + EXPECT_EQ(w.x,3); + EXPECT_EQ(w.y,4); + EXPECT_EQ(w.z,5); + EXPECT_EQ(wp.x,3); + EXPECT_EQ(wp.y,4); + EXPECT_EQ(wp.z,5); +} + +TEST_F(Vector3Test,Set) +{ + // double Vector3 + ModuleBase::Vector3 u; + u.set(da,db,dc); + EXPECT_EQ(u.x,3.0); + EXPECT_EQ(u.y,4.0); + EXPECT_EQ(u.z,5.0); + // float Vector3 + ModuleBase::Vector3 v; + v.set(fa,fb,fc); + EXPECT_EQ(v.x,3.0); + EXPECT_EQ(v.y,4.0); + EXPECT_EQ(v.z,5.0); + // int Vector3 + ModuleBase::Vector3 w; + w.set(ia,ib,ic); + EXPECT_EQ(w.x,3); + EXPECT_EQ(w.y,4); + EXPECT_EQ(w.z,5); +} + +TEST_F(Vector3Test,Equal) +{ + // double Vector3 + ModuleBase::Vector3 u, up; + u.set(da,db,dc); + up = u; + EXPECT_EQ(up.x,3.0); + EXPECT_EQ(up.y,4.0); + EXPECT_EQ(up.z,5.0); + // float Vector3 + ModuleBase::Vector3 v, vp; + v.set(fa,fb,fc); + vp = v; + EXPECT_EQ(vp.x,3.0); + EXPECT_EQ(vp.y,4.0); + EXPECT_EQ(vp.z,5.0); + // int Vector3 + ModuleBase::Vector3 w, wp; + w.set(ia,ib,ic); + wp = w; + EXPECT_EQ(wp.x,3); + EXPECT_EQ(wp.y,4); + EXPECT_EQ(wp.z,5); +} + +TEST_F(Vector3Test,PlusEqual) +{ + // double Vector3 + ModuleBase::Vector3 u, up; + u.set(da,db,dc); + up.set(da,db,dc); + up += u; + EXPECT_EQ(up.x,6.0); + EXPECT_EQ(up.y,8.0); + EXPECT_EQ(up.z,10.0); + // float Vector3 + ModuleBase::Vector3 v, vp; + v.set(fa,fb,fc); + vp.set(fa,fb,fc); + vp += v; + EXPECT_EQ(vp.x,6.0); + EXPECT_EQ(vp.y,8.0); + EXPECT_EQ(vp.z,10.0); + // int Vector3 + ModuleBase::Vector3 w, wp; + w.set(ia,ib,ic); + wp.set(ia,ib,ic); + wp += w; + EXPECT_EQ(wp.x,6); + EXPECT_EQ(wp.y,8); + EXPECT_EQ(wp.z,10); +} + +TEST_F(Vector3Test,MinusEqual) +{ + // double Vector3 + ModuleBase::Vector3 u, up; + u.set(da,db,dc); + up.set(3*da,3*db,3*dc); + up -= u; + EXPECT_EQ(up.x,6.0); + EXPECT_EQ(up.y,8.0); + EXPECT_EQ(up.z,10.0); + // float Vector3 + ModuleBase::Vector3 v, vp; + v.set(fa,fb,fc); + vp.set(3*fa,3*fb,3*fc); + vp -= v; + EXPECT_EQ(vp.x,6.0); + EXPECT_EQ(vp.y,8.0); + EXPECT_EQ(vp.z,10.0); + // int Vector3 + ModuleBase::Vector3 w, wp; + w.set(ia,ib,ic); + wp.set(3*ia,3*ib,3*ic); + wp -= w; + EXPECT_EQ(wp.x,6); + EXPECT_EQ(wp.y,8); + EXPECT_EQ(wp.z,10); +} + +TEST_F(Vector3Test,MultiplyEqual) +{ + // double Vector3 + ModuleBase::Vector3 u; + u.set(da,db,dc); + u *= 2; + EXPECT_EQ(u.x,6.0); + EXPECT_EQ(u.y,8.0); + EXPECT_EQ(u.z,10.0); + // float Vector3 + ModuleBase::Vector3 v; + v.set(fa,fb,fc); + v *= 2; + EXPECT_EQ(v.x,6.0); + EXPECT_EQ(v.y,8.0); + EXPECT_EQ(v.z,10.0); + // int Vector3 + ModuleBase::Vector3 w; + w.set(ia,ib,ic); + w *= 2; + EXPECT_EQ(w.x,6); + EXPECT_EQ(w.y,8); + EXPECT_EQ(w.z,10); +} + +TEST_F(Vector3Test,OverEqual) +{ + // double Vector3 + ModuleBase::Vector3 u; + u.set(4*da,4*db,4*dc); + u /= 2; + EXPECT_EQ(u.x,6.0); + EXPECT_EQ(u.y,8.0); + EXPECT_EQ(u.z,10.0); + // float Vector3 + ModuleBase::Vector3 v; + v.set(4*fa,4*fb,4*fc); + v /= 2; + EXPECT_EQ(v.x,6.0); + EXPECT_EQ(v.y,8.0); + EXPECT_EQ(v.z,10.0); + // int Vector3 + ModuleBase::Vector3 w; + w.set(4*ia,4*ib,4*ic); + w /= 2; + EXPECT_EQ(w.x,6); + EXPECT_EQ(w.y,8); + EXPECT_EQ(w.z,10); +} + +TEST_F(Vector3Test,Negative) +{ + // double Vector3 + ModuleBase::Vector3 u, up; + u.set(da,db,dc); + up = -u; + EXPECT_EQ(up.x,-3.0); + EXPECT_EQ(up.y,-4.0); + EXPECT_EQ(up.z,-5.0); + // float Vector3 + ModuleBase::Vector3 v, vp; + v.set(fa,fb,fc); + vp = -v; + EXPECT_EQ(vp.x,-3.0); + EXPECT_EQ(vp.y,-4.0); + EXPECT_EQ(vp.z,-5.0); + // int Vector3 + ModuleBase::Vector3 w, wp; + w.set(ia,ib,ic); + wp = -w; + EXPECT_EQ(wp.x,-3); + EXPECT_EQ(wp.y,-4); + EXPECT_EQ(wp.z,-5); +} + +TEST_F(Vector3Test,Access) +{ + // double Vector3 + ModuleBase::Vector3 u; + u.set(da,db,dc); + EXPECT_EQ(u[0],3.0); + EXPECT_EQ(u[1],4.0); + EXPECT_EQ(u[2],5.0); + // float Vector3 + ModuleBase::Vector3 v; + v.set(fa,fb,fc); + EXPECT_EQ(v.x,3.0); + EXPECT_EQ(v.y,4.0); + EXPECT_EQ(v.z,5.0); + // int Vector3 + ModuleBase::Vector3 w; + w.set(ia,ib,ic); + EXPECT_EQ(w.x,3); + EXPECT_EQ(w.y,4); + EXPECT_EQ(w.z,5); +} + + +TEST_F(Vector3Test,ConstAccess) +{ + // double Vector3 + ModuleBase::Vector3 u; + u.set(da,db,dc); + const ModuleBase::Vector3 *up(&u); + EXPECT_EQ((*up)[0],3.0); + EXPECT_EQ((*up)[1],4.0); + EXPECT_EQ((*up)[2],5.0); + // float Vector3 + ModuleBase::Vector3 v; + const ModuleBase::Vector3 *vp(&v); + v.set(fa,fb,fc); + EXPECT_EQ((*vp).x,3.0); + EXPECT_EQ((*vp).y,4.0); + EXPECT_EQ((*vp).z,5.0); + // int Vector3 + //ModuleBase::Vector3 w; + //w.set(ia,ib,ic); + //EXPECT_EQ(w.x,3); + //EXPECT_EQ(w.y,4); + //EXPECT_EQ(w.z,5); +} + + +TEST_F(Vector3Test,Reverse) +{ + // double Vector3 + ModuleBase::Vector3 u; + u.set(da,db,dc); + u.reverse(); + EXPECT_EQ(u.x,-3.0); + EXPECT_EQ(u.y,-4.0); + EXPECT_EQ(u.z,-5.0); + // float Vector3 + ModuleBase::Vector3 v; + v.set(fa,fb,fc); + v.reverse(); + EXPECT_EQ(v.x,-3.0); + EXPECT_EQ(v.y,-4.0); + EXPECT_EQ(v.z,-5.0); + // int Vector3 + ModuleBase::Vector3 w; + w.set(ia,ib,ic); + w.reverse(); + EXPECT_EQ(w.x,-3); + EXPECT_EQ(w.y,-4); + EXPECT_EQ(w.z,-5); +} + +TEST_F(Vector3Test,VectorPlus) +{ + // double Vector3 + ModuleBase::Vector3 u,up,upp; + u.set(da,db,dc); + up.set(da,db,dc); + upp = u + up; + EXPECT_EQ(upp[0],6.0); + EXPECT_EQ(upp[1],8.0); + EXPECT_EQ(upp[2],10.0); + // float Vector3 + ModuleBase::Vector3 v,vp,vpp; + v.set(fa,fb,fc); + vp.set(fa,fb,fc); + vpp = v + vp; + EXPECT_EQ(vpp.x,6.0); + EXPECT_EQ(vpp.y,8.0); + EXPECT_EQ(vpp.z,10.0); + // int Vector3 + ModuleBase::Vector3 w,wp,wpp; + w.set(ia,ib,ic); + wp.set(ia,ib,ic); + wpp = w + wp; + EXPECT_EQ(wpp.x,6); + EXPECT_EQ(wpp.y,8); + EXPECT_EQ(wpp.z,10); +} + +TEST_F(Vector3Test,VectorMinus) +{ + // double Vector3 + ModuleBase::Vector3 u,up,upp; + u.set(da,db,dc); + up.set(2*da,2*db,2*dc); + upp = u - up; + EXPECT_EQ(upp[0],-3.0); + EXPECT_EQ(upp[1],-4.0); + EXPECT_EQ(upp[2],-5.0); + // float Vector3 + ModuleBase::Vector3 v,vp,vpp; + v.set(fa,fb,fc); + vp.set(3*fa,3*fb,3*fc); + vpp = v - vp; + EXPECT_EQ(vpp.x,-6.0); + EXPECT_EQ(vpp.y,-8.0); + EXPECT_EQ(vpp.z,-10.0); + // int Vector3 + ModuleBase::Vector3 w,wp,wpp; + w.set(3*ia,3*ib,3*ic); + wp.set(ia,ib,ic); + wpp = w - wp; + EXPECT_EQ(wpp.x,6); + EXPECT_EQ(wpp.y,8); + EXPECT_EQ(wpp.z,10); +} + +TEST_F(Vector3Test,Norm2) +{ + // double Vector3 + ModuleBase::Vector3 u; + u.set(da,db,dc); + EXPECT_EQ(u.norm2(),50.0); + // float Vector3 + ModuleBase::Vector3 v; + v.set(fa,fb,fc); + EXPECT_EQ(v.norm2(),50.0); + // int Vector3 + ModuleBase::Vector3 w; + w.set(ia,ib,ic); + EXPECT_EQ(w.norm2(),50); +} + + +TEST_F(Vector3Test,Norm) +{ + // double Vector3 + ModuleBase::Vector3 u; + u.set(da,db,dc); + double nm = u.norm(); + double nm2= sqrt(50.0); + EXPECT_DOUBLE_EQ(nm,nm2); + EXPECT_FLOAT_EQ(nm,sqrt(50.0)); + // float Vector3 + ModuleBase::Vector3 v; + v.set(fa,fb,fc); + float nmp = v.norm(); + float nmp2= sqrt(50.0); + EXPECT_FLOAT_EQ(nmp,sqrt(50.0)); +} + + +TEST_F(Vector3Test,Normalize) +{ + // double Vector3 + ModuleBase::Vector3 u; + u.set(da,db,dc); + u.normalize(); + EXPECT_DOUBLE_EQ(u.norm(),1.0); + // float Vector3 + ModuleBase::Vector3 v; + v.set(fa,fb,fc); + v.normalize(); + EXPECT_FLOAT_EQ(v.norm(),1.0); +} + +TEST_F(Vector3Test,VmultiplyV) +{ + // double Vector3 + ModuleBase::Vector3 u,up; + u.set(da,db,dc); + up.set(da,db,dc); + double mpd = u * up; + EXPECT_EQ(mpd,50.0); + // float Vector3 + ModuleBase::Vector3 v,vp; + v.set(fa,fb,fc); + vp.set(fa,fb,fc); + float mpf = v*vp; + EXPECT_EQ(mpf,50.0); + // int Vector3 + ModuleBase::Vector3 w,wp; + w.set(ia,ib,ic); + wp.set(ia,ib,ic); + int mpi = w*wp; + EXPECT_EQ(mpf,50); +} + +TEST_F(Vector3Test,VdotV) +{ + // double Vector3 + ModuleBase::Vector3 u,up; + u.set(da,db,dc); + up.set(da,db,dc); + double mpd = dot(u,up); + EXPECT_EQ(mpd,50.0); + // float Vector3 + ModuleBase::Vector3 v,vp; + v.set(fa,fb,fc); + vp.set(fa,fb,fc); + float mpf = dot(v,vp); + EXPECT_EQ(mpf,50.0); + // int Vector3 + ModuleBase::Vector3 w,wp; + w.set(ia,ib,ic); + wp.set(ia,ib,ic); + int mpi = dot(w,wp); + EXPECT_EQ(mpf,50); +} + +TEST_F(Vector3Test,VmultiplyNum) +{ + // double Vector3 + ModuleBase::Vector3 u,up,upp; + u.set(da,db,dc); + double s = 3.0; + up = s*u; upp = u*s; + EXPECT_EQ(upp[0],up[0]); + EXPECT_EQ(upp[1],up[1]); + EXPECT_EQ(upp[2],up[2]); + EXPECT_EQ(upp[0],9.0); + EXPECT_EQ(upp[1],12.0); + EXPECT_EQ(upp[2],15.0); + // float Vector3 + ModuleBase::Vector3 v,vp,vpp; + v.set(fa,fb,fc); + float t = 3.0; + vp = t*v; vpp = v*t; + EXPECT_EQ(vpp[0],vp[0]); + EXPECT_EQ(vpp[1],vp[1]); + EXPECT_EQ(vpp[2],vp[2]); + EXPECT_EQ(vpp[0],9.0); + EXPECT_EQ(vpp[1],12.0); + EXPECT_EQ(vpp[2],15.0); + // int Vector3 + ModuleBase::Vector3 w,wp,wpp; + w.set(ia,ib,ic); + int q = 3; + wp = q*w; wpp = w*q; + EXPECT_EQ(wpp[0],wp[0]); + EXPECT_EQ(wpp[1],wp[1]); + EXPECT_EQ(wpp[2],wp[2]); + EXPECT_EQ(wpp[0],9.0); + EXPECT_EQ(wpp[1],12.0); + EXPECT_EQ(wpp[2],15.0); +} + +TEST_F(Vector3Test,VoverNum) +{ + // double Vector3 + ModuleBase::Vector3 u,up; + u.set(2*da,2*db,2*dc); + double s = 2.0; + up = u/s; + EXPECT_EQ(up.x,3.0); + EXPECT_EQ(up.y,4.0); + EXPECT_EQ(up.z,5.0); + // float Vector3 + ModuleBase::Vector3 v,vp; + v.set(2*fa,2*fb,2*fc); + float t = 2.0; + vp = v/t; + EXPECT_EQ(vp.x,3.0); + EXPECT_EQ(vp.y,4.0); + EXPECT_EQ(vp.z,5.0); + // int Vector3 + ModuleBase::Vector3 w,wp; + w.set(2*ia,2*ib,2*ic); + int q = 2; + wp = w/q; + EXPECT_EQ(wp.x,3); + EXPECT_EQ(wp.y,4); + EXPECT_EQ(wp.z,5); +} + +TEST_F(Vector3Test,OperatorCaret) +{ + // double Vector3 + ModuleBase::Vector3 u,up,upp; + u.set(da,db,dc); + up.set(da,db,dc); + upp = u^up; + EXPECT_EQ(upp.x,u.y*up.z - u.z*up.y); + EXPECT_EQ(upp.y,u.z*up.x - u.x*up.z); + EXPECT_EQ(upp.z,u.x*up.y - u.y*up.x); + // float Vector3 + ModuleBase::Vector3 v,vp,vpp; + v.set(2*fa,2*fb,2*fc); + vp.set(fa,fb,fc); + vpp = v^vp; + EXPECT_EQ(vpp.x,v.y*vp.z - v.z*vp.y); + EXPECT_EQ(vpp.y,v.z*vp.x - v.x*vp.z); + EXPECT_EQ(vpp.z,v.x*vp.y - v.y*vp.x); + // int Vector3 + ModuleBase::Vector3 w,wp,wpp; + w.set(2*ia,2*ib,2*ic); + wp.set(ia,ib,ic); + wpp = w^wp; + EXPECT_EQ(wpp.x,w.y*wp.z - w.z*wp.y); + EXPECT_EQ(wpp.y,w.z*wp.x - w.x*wp.z); + EXPECT_EQ(wpp.z,w.x*wp.y - w.y*wp.x); +} + +TEST_F(Vector3Test,Cross) +{ + // double Vector3 + ModuleBase::Vector3 u,up,upp; + u.set(da,db,dc); + up.set(da,db,dc); + upp = cross(u,up); + EXPECT_EQ(upp.x,u.y*up.z - u.z*up.y); + EXPECT_EQ(upp.y,u.z*up.x - u.x*up.z); + EXPECT_EQ(upp.z,u.x*up.y - u.y*up.x); + // float Vector3 + ModuleBase::Vector3 v,vp,vpp; + v.set(2*fa,2*fb,2*fc); + vp.set(fa,fb,fc); + vpp = cross(v,vp); + EXPECT_EQ(vpp.x,v.y*vp.z - v.z*vp.y); + EXPECT_EQ(vpp.y,v.z*vp.x - v.x*vp.z); + EXPECT_EQ(vpp.z,v.x*vp.y - v.y*vp.x); + // int Vector3 + ModuleBase::Vector3 w,wp,wpp; + w.set(2*ia,2*ib,2*ic); + wp.set(ia,ib,ic); + wpp = cross(w,wp); + EXPECT_EQ(wpp.x,w.y*wp.z - w.z*wp.y); + EXPECT_EQ(wpp.y,w.z*wp.x - w.x*wp.z); + EXPECT_EQ(wpp.z,w.x*wp.y - w.y*wp.x); +} + +TEST_F(Vector3Test,VeqV) +{ + // double Vector3 + ModuleBase::Vector3 u,up; + u.set(da,db,dc); + up.set(da,db,dc); + EXPECT_TRUE(up == u); + // float Vector3 + ModuleBase::Vector3 v,vp; + v.set(fa,fb,fc); + vp.set(fa,fb,fc); + EXPECT_TRUE(vp == v); + // int Vector3 + ModuleBase::Vector3 w,wp; + w.set(ia,ib,ic); + wp.set(ia,ib,ic); + EXPECT_TRUE(wp == w); +} + +TEST_F(Vector3Test,VneV) +{ + // double Vector3 + ModuleBase::Vector3 u,up; + u.set(da,db,dc); + up.set(da,db,2*dc); + EXPECT_TRUE(up != u); + // float Vector3 + ModuleBase::Vector3 v,vp; + v.set(fa,fb,2*fc); + vp.set(fa,fb,fc); + EXPECT_TRUE(vp != v); + // int Vector3 + ModuleBase::Vector3 w,wp; + w.set(ia,ib,2*ic); + wp.set(ia,ib,ic); + EXPECT_TRUE(wp != w); +} + +TEST_F(Vector3Test,StdOutV) +{ + // double Vector3 + ModuleBase::Vector3 u(da,db,dc); + testing::internal::CaptureStdout(); + std::cout << u << std::endl; + output = testing::internal::GetCapturedStdout(); + EXPECT_THAT(output,testing::HasSubstr("(")); + // float Vector3 + ModuleBase::Vector3 v(fa,fb,fc); + testing::internal::CaptureStdout(); + std::cout << v << std::endl; + output = testing::internal::GetCapturedStdout(); + EXPECT_THAT(output,testing::HasSubstr(",")); + // int Vector3 + ModuleBase::Vector3 w(ia,ib,ic); + testing::internal::CaptureStdout(); + std::cout << w << std::endl; + output = testing::internal::GetCapturedStdout(); + EXPECT_THAT(output,testing::HasSubstr(")")); +} + +TEST_F(Vector3Test,PrintV) +{ + // double Vector3 + ModuleBase::Vector3 u(3.1415926,db,dc); + testing::internal::CaptureStdout(); + u.print(); + output = testing::internal::GetCapturedStdout(); + EXPECT_THAT(output,testing::HasSubstr("3.1416")); + // float Vector3 + ModuleBase::Vector3 v(fa,fb,3.14); + testing::internal::CaptureStdout(); + v.print(); + output = testing::internal::GetCapturedStdout(); + EXPECT_THAT(output,testing::HasSubstr("3.14")); + // int Vector3 + ModuleBase::Vector3 w(ia,101,ic); + testing::internal::CaptureStdout(); + w.print(); + output = testing::internal::GetCapturedStdout(); + EXPECT_THAT(output,testing::HasSubstr("101")); +} + diff --git a/source/module_base/vector3.h b/source/module_base/vector3.h index f230f616f3c..94a09afcfdb 100644 --- a/source/module_base/vector3.h +++ b/source/module_base/vector3.h @@ -6,103 +6,372 @@ #endif #include -#include #include +#include namespace ModuleBase { -template -class Vector3 +/** + * @brief 3 elements vector + * + * @tparam T + */ +template class Vector3 { -public: - T x; - T y; - T z; - - Vector3(const T &x1 = 0,const T &y1 = 0,const T &z1 = 0) :x(x1),y(y1),z(z1){}; - Vector3(const Vector3 &v) :x(v.x),y(v.y),z(v.z){}; // Peize Lin add 2018-07-16 - void set(const T &x1, const T &y1,const T &z1) { x = x1; y = y1; z = z1; } - - Vector3& operator =(const Vector3 &u) { x=u.x; y=u.y; z=u.z; return *this; } - Vector3& operator+=(const Vector3 &u) { x+=u.x; y+=u.y; z+=u.z; return *this; } - Vector3& operator-=(const Vector3 &u) { x-=u.x; y-=u.y; z-=u.z; return *this; } - Vector3& operator*=(const Vector3 &u); - Vector3& operator*=(const T &s) { x*=s; y*=s; z*=s; return *this; } - Vector3& operator/=(const Vector3 &u); - Vector3& operator/=(const T &s) { x/=s; y/=s; z/=s; return *this; } - Vector3 operator -() const { return Vector3(-x,-y,-z); } // Peize Lin add 2017-01-10 - - T operator[](int index)const { return (&x)[index]; } - T& operator[](int index) { return (&x)[index]; } - - T norm2(void) const { return x*x + y*y + z*z; } - T norm(void) const { return sqrt(norm2()); } - Vector3& normalize(void){ const T m=norm(); x/=m; y/=m; z/=m; return *this; } // Peize Lin update return 2019-09-08 - Vector3& reverse(void){ x=-x; y=-y; z=-z; return *this; } // Peize Lin update return 2019-09-08 - - void print(void)const ; // mohan add 2009-11-29 + public: + T x; + T y; + T z; + + /** + * @brief Construct a new Vector 3 object + * + * @param x1 + * @param y1 + * @param z1 + */ + Vector3(const T &x1 = 0, const T &y1 = 0, const T &z1 = 0) : x(x1), y(y1), z(z1){}; + Vector3(const Vector3 &v) : x(v.x), y(v.y), z(v.z){}; // Peize Lin add 2018-07-16 + + /** + * @brief set a 3d vector + * + * @param x1 + * @param y1 + * @param z1 + */ + void set(const T &x1, const T &y1, const T &z1) + { + x = x1; + y = y1; + z = z1; + } + + /** + * @brief Overload operator "=" for Vector3 + * + * @param u + * @return Vector3& + */ + Vector3 &operator=(const Vector3 &u) + { + x = u.x; + y = u.y; + z = u.z; + return *this; + } + + /** + * @brief Overload operator "+=" for Vector3 + * + * @param u + * @return Vector3& + */ + Vector3 &operator+=(const Vector3 &u) + { + x += u.x; + y += u.y; + z += u.z; + return *this; + } + + /** + * @brief Overload operator "-=" for Vector3 + * + * @param u + * @return Vector3& + */ + Vector3 &operator-=(const Vector3 &u) + { + x -= u.x; + y -= u.y; + z -= u.z; + return *this; + } + + /** + * @brief Overload operator "*=" for (Vector3)*scalar + * + * @param s + * @return Vector3& + */ + Vector3 &operator*=(const T &s) + { + x *= s; + y *= s; + z *= s; + return *this; + } + + /** + * @brief Overload operator "/=" for (Vector3)/scalar + * + * @param s + * @return Vector3& + */ + Vector3 &operator/=(const T &s) + { + x /= s; + y /= s; + z /= s; + return *this; + } + + /** + * @brief Overload operator "-" to get (-Vector3) + * + * @return Vector3 + */ + Vector3 operator-() const + { + return Vector3(-x, -y, -z); + } // Peize Lin add 2017-01-10 + + /** + * @brief Over load "[]" for accessing elements with pointers + * + * @param index + * @return T + */ + T operator[](int index) const + { + return (&x)[index]; + } + + /** + * @brief Overload operator "[]" for accesing elements + * + * @param index + * @return T& + */ + T &operator[](int index) + { + return (&x)[index]; + } + + /** + * @brief Get the square of nomr of a Vector3 + * + * @return T + */ + T norm2(void) const + { + return x * x + y * y + z * z; + } + + /** + * @brief Get the norm of a Vector3 + * + * @return T + */ + T norm(void) const + { + return sqrt(norm2()); + } + + /** + * @brief Normalize a Vector3 + * + * @return Vector3& + */ + Vector3 &normalize(void) + { + const T m = norm(); + x /= m; + y /= m; + z /= m; + return *this; + } // Peize Lin update return 2019-09-08 + + /** + * @brief Get (-Vector3) + * + * @return Vector3& + */ + Vector3 &reverse(void) + { + x = -x; + y = -y; + z = -z; + return *this; + } // Peize Lin update return 2019-09-08 + + /** + * @brief Print a Vector3 on standard output + * with formats + * + */ + void print(void) const; // mohan add 2009-11-29 }; -template inline Vector3 operator+( const Vector3 &u, const Vector3 &v ) { return Vector3( u.x+v.x, u.y+v.y, u.z+v.z ); } -template inline Vector3 operator-( const Vector3 &u, const Vector3 &v ) { return Vector3( u.x-v.x, u.y-v.y, u.z-v.z ); } -//u.v=(ux*vx)+(uy*vy)+(uz*vz) -template inline T operator*( const Vector3 &u, const Vector3 &v ) { return ( u.x*v.x + u.y*v.y + u.z*v.z ); } -template inline Vector3 operator*( const T &s, const Vector3 &u ) { return Vector3( u.x*s, u.y*s, u.z*s ); } -template inline Vector3 operator*( const Vector3 &u, const T &s ) { return Vector3( u.x*s, u.y*s, u.z*s ); } // mohan add 2009-5-10 -template inline Vector3 operator/( const Vector3 &u, const T &s ) { return Vector3( u.x/s, u.y/s, u.z/s ); } -//u.v=(ux*vx)+(uy*vy)+(uz*vz) -template inline T dot ( const Vector3 &u, const Vector3 &v ) { return ( u.x*v.x + u.y*v.y + u.z*v.z ); } -// | i j k | -// | ux uy uz | -// | vx vy vz | -// u.v=(uy*vz-uz*vy)i+(-ux*vz+uz*vx)j+(ux*vy-uy*vx)k -template inline Vector3 operator^(const Vector3 &u,const Vector3 &v) -{ - return Vector3 ( u.y * v.z - u.z * v.y, - -u.x * v.z + u.z * v.x, - u.x * v.y - u.y * v.x); +/** + * @brief Overload "+" for two Vector3 + * + * @param[in] u + * @param[in] v + * @return Vector3 + */ +template inline Vector3 operator+(const Vector3 &u, const Vector3 &v) +{ + return Vector3(u.x + v.x, u.y + v.y, u.z + v.z); } -// | i j k | -// | ux uy uz | -// | vx vy vz | -// u.v=(uy*vz-uz*vy)i+(-ux*vz+uz*vx)j+(ux*vy-uy*vzx)k -template inline Vector3 cross(const Vector3 &u,const Vector3 &v) + +/** + * @brief Overload "-" for two Vector3 + * + * @param[in] u + * @param[in] v + * @return Vector3 + */ +template inline Vector3 operator-(const Vector3 &u, const Vector3 &v) +{ + return Vector3(u.x - v.x, u.y - v.y, u.z - v.z); +} + +/** + * @brief Overload "*" to calculate the dot product + * of two Vector3 + * + * @param u + * @param v + * @return template + */ +template inline T operator*(const Vector3 &u, const Vector3 &v) { - return Vector3 ( u.y * v.z - u.z * v.y, - -u.x * v.z + u.z * v.x, - u.x * v.y - u.y * v.x); + return (u.x * v.x + u.y * v.y + u.z * v.z); } -//s = u.(v x w) -//template T TripleScalarProduct(Vector3 u, Vector3 v, Vector3 w) + +/** + * @brief Overload "*" to calculate (Vector3)*scalar + * + * @param[in] s + * @param[in] u + * @return Vector3 + */ +template inline Vector3 operator*(const T &s, const Vector3 &u) +{ + return Vector3(u.x * s, u.y * s, u.z * s); +} + +/** + * @brief Overload "*" to calculate scalar*(Vector3) + * + * @param u + * @param s + * @return Vector3 + */ +template inline Vector3 operator*(const Vector3 &u, const T &s) +{ + return Vector3(u.x * s, u.y * s, u.z * s); +} // mohan add 2009-5-10 + +/** + * @brief Overload "/" to calculate Vector3/scalar + * + * @tparam T + * @param u + * @param s + * @return Vector3 + */ +template inline Vector3 operator/(const Vector3 &u, const T &s) +{ + return Vector3(u.x / s, u.y / s, u.z / s); +} + +/** + * @brief Dot productor of two Vector3 + * + * @param u + * @param v + * @return T + * @note u.v=(ux*vx)+(uy*vy)+(uz*vz) + */ +template inline T dot(const Vector3 &u, const Vector3 &v) +{ + return (u.x * v.x + u.y * v.y + u.z * v.z); +} + +/** + * @brief Overload "^" for cross product of two Vector3 + * + * @param u + * @param v + * @return template + * @note + * | i j k | + * | ux uy uz | + * | vx vy vz | + * u.v=(uy*vz-uz*vy)i+(-ux*vz+uz*vx)j+(ux*vy-uy*vx)k + */ +template inline Vector3 operator^(const Vector3 &u, const Vector3 &v) +{ + return Vector3(u.y * v.z - u.z * v.y, -u.x * v.z + u.z * v.x, u.x * v.y - u.y * v.x); +} + +/** + * @brief Cross product of two Vector3 + * + * @param u + * @param v + * @return template + * @note + * | i j k | + * | ux uy uz | + * | vx vy vz | + * u.v=(uy*vz-uz*vy)i+(-ux*vz+uz*vx)j+(ux*vy-uy*vx)k + */ +template inline Vector3 cross(const Vector3 &u, const Vector3 &v) +{ + return Vector3(u.y * v.z - u.z * v.y, -u.x * v.z + u.z * v.x, u.x * v.y - u.y * v.x); +} +// s = u.(v x w) +// template T TripleScalarProduct(Vector3 u, Vector3 v, Vector3 w) //{ // return T((u.x * (v.y * w.z - v.z * w.y)) + // (u.y * (-v.x * w.z + v.z * w.x)) + // (u.z * (v.x * w.y - v.y * w.x))); -//} +// } -//whether m1 != m2 -template inline bool operator !=(const Vector3 &u, const Vector3 &v){ return !(u == v); } -//whether u == v -template inline bool operator ==(const Vector3 &u, const Vector3 &v) +// whether m1 != m2 +template inline bool operator!=(const Vector3 &u, const Vector3 &v) { - if(u.x == v.x && u.y == v.y && u.z == v.z) - return true; - return false; + return !(u == v); } - -template void Vector3::print(void)const +// whether u == v +template inline bool operator==(const Vector3 &u, const Vector3 &v) { - std::cout.precision(5) ; - std::cout << "(" << std::setw(10) << x << "," << std::setw(10) << y << "," - << std::setw(10) << z << ")" << std::endl ; - return ; + if (u.x == v.x && u.y == v.y && u.z == v.z) + return true; + return false; } -template static std::ostream & operator << ( std::ostream &os, const Vector3 &u ) + +/** + * @brief Print a Vector3 on standard output with formats + * + */ +template void Vector3::print(void) const { - os << "(" << std::setw(10) << u.x << "," << std::setw(10) << u.y << "," << std::setw(10) << u.z << ")"; - return os; + std::cout.precision(5); + std::cout << "(" << std::setw(10) << x << "," << std::setw(10) << y << "," << std::setw(10) << z << ")" + << std::endl; + return; } +/** + * @brief Overload "<<" tor print out a + * Vector3 on standard output + * + * @tparam T + * @param[in] os + * @param[in] u + * @return std::ostream& + */ +template static std::ostream &operator<<(std::ostream &os, const Vector3 &u) +{ + os << "(" << std::setw(10) << u.x << "," << std::setw(10) << u.y << "," << std::setw(10) << u.z << ")"; + return os; } +} // namespace ModuleBase + #endif diff --git a/source/module_base/ylm.h b/source/module_base/ylm.h index c4b9c7b93e9..5cb98085f48 100644 --- a/source/module_base/ylm.h +++ b/source/module_base/ylm.h @@ -15,37 +15,76 @@ class Ylm static int nlm; - // (1) for check + + /** + * @brief Get the ylm real object + * + * @param Lmax [in] maximum angular quantum number + 1 + * @param vec [in] the vector to be calculated + * @param ylmr [out] calculated Ylm, Y00, Y10, Y11, Y1-1, Y20, Y21, Y2-1, Y22, Y2-2... + */ static void get_ylm_real( const int &Lmax , const ModuleBase::Vector3 &vec, double ylmr[]); - // (2) for check + /** + * @brief Get the ylm real object and the gradient + * + * @param Lmax [in] maximum angular quantum number + l + * @param vec [in] the vector to be calculated + * @param ylmr [out] calculated Ylm, Y00, Y10, Y11, Y1-1, Y20, Y21, Y2-1, Y22, Y2-2... + * @param dylmdr [out] gradient of Ylm, [dY00/dx, dY00/dy, dY00/dz], [dY10/dx, dY10/dy, dY10/dz], [dY11/dx, dY11/dy, dY11/dz],... + */ static void get_ylm_real( const int &Lmax , const ModuleBase::Vector3 &vec, double ylmr[], double dylmdr[][3]); - // (3) not used anymore. + /** + * @brief Get the ylm real (solid) object (not used anymore) + * + * @param Lmax [in] maximum angular quantum number + l + * @param x [in] x + * @param y [in] y + * @param z [in] z + * @param rly [in] calculated Ylm, Y00, Y10, Y11, Y1-1, Y20, Y21, Y2-1, Y22, Y2-2... + */ static void rlylm( - const int& Lmax, // max momentum of l +1 + const int& Lmax, const double& x, const double& y, const double& z, double rly[]); - // (4) not used anymore. + /** + * @brief Get the ylm real (solid) object and the gradient (not used anymore) + * + * @param Lmax [in] maximum angular quantum number + 1 + * @param x [in] x + * @param y [in] y + * @param z [in] z + * @param rly [in] calculated Ylm, Y00, Y10, Y11, Y1-1, Y20, Y21, Y2-1, Y22, Y2-2... + * @param grly [out] gradient of Ylm, [dY00/dx, dY00/dy, dY00/dz], [dY10/dx, dY10/dy, dY10/dz], [dY11/dx, dY11/dy, dY11/dz],... + */ static void rlylm( - const int& Lmax, // max momentum of l +1 + const int& Lmax, const double& x, const double& y, const double& z, double rly[], double grly[][3]); - - // (5) used in grid integration. + + /** + * @brief Get the ylm real object (used in grid integration) + * + * @param Lmax [in] maximum angular quantum number + * @param xdr [in] x/r + * @param ydr [in] y/r + * @param zdr [in] z/r + * @param rly [in] calculated Ylm, Y00, Y10, Y11, Y1-1, Y20, Y21, Y2-1, Y22, Y2-2... + */ static void sph_harm( const int& Lmax, const double& xdr, @@ -53,8 +92,17 @@ class Ylm const double& zdr, std::vector &rly); - // (6) used in getting overlap. - // Peize Lin change rly 2016-08-26 + /** + * @brief Get the ylm real object (used in getting overlap) + * + * @param Lmax [in] maximum angular quantum number + * @param x [in] x/r + * @param y [in] y/r + * @param z [in] z/r + * @param rly [in] calculated Ylm, Y00, Y10, Y11, Y1-1, Y20, Y21, Y2-1, Y22, Y2-2... + * @author Peize Lin + * @date 2016-08-26 + */ static void rl_sph_harm( const int& Lmax, const double& x, @@ -62,8 +110,16 @@ class Ylm const double& z, std::vector& rly); - // (6) used in getting derivative of overlap. - // Peize Lin change rly, grly 2016-08-26 + /** + * @brief Get the ylm real object and the gradient (used in getting derivative of overlap) + * + * @param Lmax [in] maximum angular quantum number + * @param x [in] x/r + * @param y [in] y/r + * @param z [in] z/r + * @param rly [in] calculated Ylm, Y00, Y10, Y11, Y1-1, Y20, Y21, Y2-1, Y22, Y2-2... + * @param grly [out] gradient of Ylm, [dY00/dx, dY00/dy, dY00/dz], [dY10/dx, dY10/dy, dY10/dz], [dY11/dx, dY11/dy, dY11/dz],... + */ static void grad_rl_sph_harm( const int& Lmax, const double& x, @@ -71,14 +127,15 @@ class Ylm const double& z, std::vector& rly, std::vector>& grly); - + + //calculate the coefficient of Ylm, ylmcoef. static void set_coefficients (); static std::vector ylmcoef; static void test(); static void test1(); static void test2(); - + //set the first n elements of u to be 0.0 static void ZEROS(double u[], const int& n); private: diff --git a/source/module_ensolver/CMakeLists.txt b/source/module_ensolver/CMakeLists.txt new file mode 100644 index 00000000000..f24a61914ab --- /dev/null +++ b/source/module_ensolver/CMakeLists.txt @@ -0,0 +1,11 @@ +add_library( + en_solver + OBJECT + en_solver.cpp + FP/ab_initio.cpp + FP/KSDFT/ks_scf.cpp + FP/KSDFT/PW/ks_scf_pw.cpp + FP/KSDFT/LCAO/ks_scf_lcao.cpp +) + + diff --git a/source/module_ensolver/FP/KSDFT/LCAO/ks_scf_lcao.cpp b/source/module_ensolver/FP/KSDFT/LCAO/ks_scf_lcao.cpp new file mode 100644 index 00000000000..3aad41d62ca --- /dev/null +++ b/source/module_ensolver/FP/KSDFT/LCAO/ks_scf_lcao.cpp @@ -0,0 +1,136 @@ +#include "ks_scf_lcao.h" + +//--------------temporary---------------------------- +#include "../../../../src_pw/global.h" +#include "../../../../module_base/global_function.h" +#include "src_io/print_info.h" + +#ifdef __DEEPKS +#include "module_deepks/LCAO_deepks.h" +#endif +//-----force------------------- + +//-----stress------------------ + +//--------------------------------------------------- + +namespace ModuleEnSover +{ + +void KS_SCF_LCAO::Init(Input &inp, UnitCell_pseudo &ucell) +{ + + // setup GlobalV::NBANDS + // Yu Liu add 2021-07-03 + GlobalC::CHR.cal_nelec(); + + // mohan add 2010-09-06 + // Yu Liu move here 2021-06-27 + // because the number of element type + // will easily be ignored, so here + // I warn the user again for each type. + for(int it=0; itLOE.solve_elec_stru(istep, ra, loc, lowf, uhm); + return ; +} + +void KS_SCF_LCAO::cal_Energy(energy &en) +{ + +} + +void KS_SCF_LCAO::cal_Force(ModuleBase::matrix &force) +{ + +} +void KS_SCF_LCAO::cal_Stress(ModuleBase::matrix &stress) +{ + +} + +} \ No newline at end of file diff --git a/source/module_ensolver/FP/KSDFT/LCAO/ks_scf_lcao.h b/source/module_ensolver/FP/KSDFT/LCAO/ks_scf_lcao.h new file mode 100644 index 00000000000..0f9d36d0785 --- /dev/null +++ b/source/module_ensolver/FP/KSDFT/LCAO/ks_scf_lcao.h @@ -0,0 +1,39 @@ +#ifndef KS_SCF_LCAO_H +#define KS_SCF_LCAO_H +#include "../ks_scf.h" + +#include "src_lcao/LOOP_elec.h" + +namespace ModuleEnSover +{ + +class KS_SCF_LCAO: public KS_SCF +{ +public: + KS_SCF_LCAO() + { + tag = "KS_SCF_LCAO"; + } + void Init(Input& inp, UnitCell_pseudo& cell) override; + + void Run(int istep, + Record_adj& ra, + Local_Orbital_Charge& loc, + Local_Orbital_wfc& lowf, + LCAO_Hamilt& uhm) override; + void Run(int istep, UnitCell_pseudo& cell) override {}; + + void cal_Energy(energy& en) override; + void cal_Force(ModuleBase::matrix &force) override; + void cal_Stress(ModuleBase::matrix& stress) override; + +private: + LOOP_elec LOE; +}; + +///Basis_lcao +///ORB_control orb_con; + + +} +#endif \ No newline at end of file diff --git a/source/module_ensolver/FP/KSDFT/PW/ks_scf_pw.cpp b/source/module_ensolver/FP/KSDFT/PW/ks_scf_pw.cpp new file mode 100644 index 00000000000..08c90207eb3 --- /dev/null +++ b/source/module_ensolver/FP/KSDFT/PW/ks_scf_pw.cpp @@ -0,0 +1,190 @@ +#include "ks_scf_pw.h" + +//--------------temporary---------------------------- +#include "../../../../src_pw/global.h" +#include "../../../../module_base/global_function.h" +#include "../../../../module_symmetry/symmetry.h" +#include "../../../../src_pw/vdwd2.h" +#include "../../../../src_pw/vdwd3.h" +#include "../../../../src_pw/vdwd2_parameters.h" +#include "../../../../src_pw/vdwd3_parameters.h" +#include "../../../../src_pw/pw_complement.h" +#include "../../../../src_pw/pw_basis.h" +#include "../../../../src_io/print_info.h" +//-----force------------------- +#include "../../../../src_pw/forces.h" +//-----stress------------------ +#include "../../../../src_pw/stress_pw.h" +//--------------------------------------------------- + +namespace ModuleEnSover +{ + +void KS_SCF_PW::Init(Input &inp, UnitCell_pseudo &ucell) +{ + // setup GlobalV::NBANDS + // Yu Liu add 2021-07-03 + GlobalC::CHR.cal_nelec(); + + // mohan add 2010-09-06 + // Yu Liu move here 2021-06-27 + // because the number of element type + // will easily be ignored, so here + // I warn the user again for each type. + for(int it=0; itph2e=new H2E_SDFT(); + // } + // else + // { + // this->ph2e=new H2E_PW(); + // } + // this->pes= new Estate_PW(); + // this->phamilt=new Hamilt_PW(); + // } + + // Basis_PW basis_pw; + // Init(Inputs &inp, Cell &cel) + // { + + // basis_pw.init(inp, cel); + + // pes->init(inp, cel, basis_pw); + + // phamilt->init(bas); + // phamilt->initpot(cel, pes); + + // ph2e->init(h, pes); + // } +}; +} +#endif \ No newline at end of file diff --git a/source/module_ensolver/FP/KSDFT/ks_scf.cpp b/source/module_ensolver/FP/KSDFT/ks_scf.cpp new file mode 100644 index 00000000000..e69de29bb2d diff --git a/source/module_ensolver/FP/KSDFT/ks_scf.h b/source/module_ensolver/FP/KSDFT/ks_scf.h new file mode 100644 index 00000000000..160eca6b746 --- /dev/null +++ b/source/module_ensolver/FP/KSDFT/ks_scf.h @@ -0,0 +1,17 @@ +#ifndef KS_SCF_H +#define KS_SCF_H +#include "../ab_initio.h" +// #include "estates.h" +// #include "h2e.h" +namespace ModuleEnSover +{ + +class KS_SCF: public ab_initio +{ + public: + // Psi *p_wf; + // Estate *p_es; + // H2E *p_h2e; +}; +} +#endif \ No newline at end of file diff --git a/source/module_ensolver/FP/OFDFT/ofdft.h b/source/module_ensolver/FP/OFDFT/ofdft.h new file mode 100644 index 00000000000..39b2fae4fa1 --- /dev/null +++ b/source/module_ensolver/FP/OFDFT/ofdft.h @@ -0,0 +1,8 @@ +#include "../ab_initio.h" +namespace ModuleEnSover +{ +class OFDFT: public ab_initio +{ + +}; +} \ No newline at end of file diff --git a/source/module_ensolver/FP/ab_initio.cpp b/source/module_ensolver/FP/ab_initio.cpp new file mode 100644 index 00000000000..e69de29bb2d diff --git a/source/module_ensolver/FP/ab_initio.h b/source/module_ensolver/FP/ab_initio.h new file mode 100644 index 00000000000..92f371bfc64 --- /dev/null +++ b/source/module_ensolver/FP/ab_initio.h @@ -0,0 +1,14 @@ +#ifndef AB_INITIO_H +#define AB_INITIO_H +#include "../en_solver.h" +// #include "hamilt.h" +namespace ModuleEnSover +{ +class ab_initio: public En_Solver +{ + public: + // Hamilt* phamilt; +}; +} + +#endif \ No newline at end of file diff --git a/source/module_ensolver/Makefile.ensolver b/source/module_ensolver/Makefile.ensolver new file mode 100644 index 00000000000..bec2e80b169 --- /dev/null +++ b/source/module_ensolver/Makefile.ensolver @@ -0,0 +1,15 @@ +VPATH:=$(VPATH)\ +:./module_ensolver\ +:./module_ensolver/FP\ +:./module_ensolver/FP/KSDFT\ +:./module_ensolver/FP/KSDFT/PW\ +:./module_ensolver/FP/KSDFT/LCAO\ + +OBJS_ENSOLVER=en_solver.o\ +ab_initio.o\ +ks_scf.o\ +ks_scf_pw.o\ +ks_scf_lcao.o + +OBJS_FIRST_PRINCIPLES:= ${OBJS_FIRST_PRINCIPLES} \ +${OBJS_ENSOLVER} \ No newline at end of file diff --git a/source/module_ensolver/en_solver.cpp b/source/module_ensolver/en_solver.cpp new file mode 100644 index 00000000000..0457f94a241 --- /dev/null +++ b/source/module_ensolver/en_solver.cpp @@ -0,0 +1,43 @@ +#include "en_solver.h" +#include "FP/KSDFT/PW/ks_scf_pw.h" +#include "FP/KSDFT/LCAO/ks_scf_lcao.h" +#include "FP/OFDFT/ofdft.h" +#include "stdio.h" +namespace ModuleEnSover +{ +void En_Solver:: printag() +{ + std::cout<* vel); static void force_virial( + ModuleEnSover::En_Solver *p_ensolver, const int &istep, const MD_parameters &mdp, const UnitCell_pseudo &unit_in, diff --git a/source/module_md/MSST.cpp b/source/module_md/MSST.cpp index cc0f79a782d..b339393950b 100644 --- a/source/module_md/MSST.cpp +++ b/source/module_md/MSST.cpp @@ -4,6 +4,7 @@ #include "mpi.h" #endif #include "../module_base/timer.h" +#include "module_ensolver/en_solver.h" MSST::MSST(MD_parameters& MD_para_in, UnitCell_pseudo &unit_in) : Verlet(MD_para_in, unit_in) { @@ -32,14 +33,14 @@ MSST::~MSST() delete []old_v; } -void MSST::setup() +void MSST::setup(ModuleEnSover::En_Solver *p_ensolver) { ModuleBase::TITLE("MSST", "setup"); ModuleBase::timer::tick("MSST", "setup"); int sd = mdp.direction; - MD_func::force_virial(step_, mdp, ucell, potential, force, virial); + MD_func::force_virial(p_ensolver, step_, mdp, ucell, potential, force, virial); MD_func::kinetic_stress(ucell, vel, allmass, kinetic, stress); stress += virial; diff --git a/source/module_md/MSST.h b/source/module_md/MSST.h index dfee7a58205..36f17c0dad1 100644 --- a/source/module_md/MSST.h +++ b/source/module_md/MSST.h @@ -9,7 +9,7 @@ class MSST : public Verlet MSST(MD_parameters& MD_para_in, UnitCell_pseudo &unit_in); ~MSST(); - void setup(); + void setup(ModuleEnSover::En_Solver *p_ensolve); void first_half(); void second_half(); void outputMD(); diff --git a/source/module_md/NVE.cpp b/source/module_md/NVE.cpp index 6bfb8aab0e3..b353b4ceccd 100644 --- a/source/module_md/NVE.cpp +++ b/source/module_md/NVE.cpp @@ -6,12 +6,12 @@ NVE::NVE(MD_parameters& MD_para_in, UnitCell_pseudo &unit_in) : Verlet(MD_para_i NVE::~NVE(){} -void NVE::setup() +void NVE::setup(ModuleEnSover::En_Solver *p_ensolve) { ModuleBase::TITLE("NVE", "setup"); ModuleBase::timer::tick("NVE", "setup"); - Verlet::setup(); + Verlet::setup(p_ensolve); ModuleBase::timer::tick("NVE", "setup"); } diff --git a/source/module_md/NVE.h b/source/module_md/NVE.h index a71339ffebd..b3ac6bfeb37 100644 --- a/source/module_md/NVE.h +++ b/source/module_md/NVE.h @@ -9,7 +9,7 @@ class NVE : public Verlet NVE(MD_parameters& MD_para_in, UnitCell_pseudo &unit_in); ~NVE(); - void setup(); + void setup(ModuleEnSover::En_Solver *p_ensolve); void first_half(); void second_half(); void outputMD(); diff --git a/source/module_md/NVT_ADS.cpp b/source/module_md/NVT_ADS.cpp index d891eecbcdb..fbe79aadd07 100644 --- a/source/module_md/NVT_ADS.cpp +++ b/source/module_md/NVT_ADS.cpp @@ -15,12 +15,12 @@ NVT_ADS::NVT_ADS(MD_parameters& MD_para_in, UnitCell_pseudo &unit_in) : Verlet(M NVT_ADS::~NVT_ADS(){} -void NVT_ADS::setup() +void NVT_ADS::setup(ModuleEnSover::En_Solver *p_ensolve) { ModuleBase::TITLE("NVT_ADS", "setup"); ModuleBase::timer::tick("NVT_ADS", "setup"); - Verlet::setup(); + Verlet::setup(p_ensolve); ModuleBase::timer::tick("NVT_ADS", "setup"); } diff --git a/source/module_md/NVT_ADS.h b/source/module_md/NVT_ADS.h index 8870065d80f..b5b01bccbe1 100644 --- a/source/module_md/NVT_ADS.h +++ b/source/module_md/NVT_ADS.h @@ -9,7 +9,7 @@ class NVT_ADS : public Verlet NVT_ADS(MD_parameters& MD_para_in, UnitCell_pseudo &unit_in); ~NVT_ADS(); - void setup(); + void setup(ModuleEnSover::En_Solver *p_ensolve); void first_half(); void second_half(); void outputMD(); diff --git a/source/module_md/NVT_NHC.cpp b/source/module_md/NVT_NHC.cpp index 53bd7dc7194..96f02d1f3f1 100644 --- a/source/module_md/NVT_NHC.cpp +++ b/source/module_md/NVT_NHC.cpp @@ -46,12 +46,12 @@ NVT_NHC::~NVT_NHC() delete []veta; } -void NVT_NHC::setup() +void NVT_NHC::setup(ModuleEnSover::En_Solver *p_ensolve) { ModuleBase::TITLE("NVT_NHC", "setup"); ModuleBase::timer::tick("NVT_NHC", "setup"); - Verlet::setup(); + Verlet::setup(p_ensolve); temp_target(); diff --git a/source/module_md/NVT_NHC.h b/source/module_md/NVT_NHC.h index c0bf3d51781..e38b3a3b64f 100644 --- a/source/module_md/NVT_NHC.h +++ b/source/module_md/NVT_NHC.h @@ -9,7 +9,7 @@ class NVT_NHC : public Verlet NVT_NHC(MD_parameters& MD_para_in, UnitCell_pseudo &unit_in); ~NVT_NHC(); - void setup(); + void setup(ModuleEnSover::En_Solver *p_ensolve); void first_half(); void second_half(); void outputMD(); diff --git a/source/module_md/run_md_classic.cpp b/source/module_md/run_md_classic.cpp index 3e693976b9c..0ecc9521a3c 100644 --- a/source/module_md/run_md_classic.cpp +++ b/source/module_md/run_md_classic.cpp @@ -9,6 +9,7 @@ #include "../input.h" #include "../src_io/print_info.h" #include "../module_base/timer.h" +#include "module_ensolver/en_solver.h" Run_MD_CLASSIC::Run_MD_CLASSIC(){} @@ -18,6 +19,7 @@ void Run_MD_CLASSIC::classic_md_line(void) { ModuleBase::TITLE("Run_MD_CLASSIC", "classic_md_line"); ModuleBase::timer::tick("Run_MD_CLASSIC", "classic_md_line"); + ModuleEnSover::En_Solver* p_ensolver; //qianrui add it temporarily // Setup the unitcell. #ifdef __LCAO @@ -59,14 +61,14 @@ void Run_MD_CLASSIC::classic_md_line(void) { if(verlet->step_ == 0) { - verlet->setup(); + verlet->setup(p_ensolver); } else { verlet->first_half(); // update force and virial due to the update of atom positions - MD_func::force_virial(verlet->step_, verlet->mdp, verlet->ucell, verlet->potential, verlet->force, verlet->virial); + MD_func::force_virial(p_ensolver, verlet->step_, verlet->mdp, verlet->ucell, verlet->potential, verlet->force, verlet->virial); verlet->second_half(); diff --git a/source/module_md/test/CMakeLists.txt b/source/module_md/test/CMakeLists.txt index 3458a0e6e79..4304377ddc2 100644 --- a/source/module_md/test/CMakeLists.txt +++ b/source/module_md/test/CMakeLists.txt @@ -42,18 +42,21 @@ list(APPEND depend_files AddTest( TARGET md_LJ_pot + LIBS ${math_libs} SOURCES LJ_pot_test.cpp ${depend_files} ) AddTest( TARGET md_func + LIBS ${math_libs} SOURCES MD_func_test.cpp ${depend_files} ) AddTest( TARGET md_fire + LIBS ${math_libs} SOURCES FIRE_test.cpp ../verlet.cpp ../FIRE.cpp @@ -62,6 +65,7 @@ AddTest( AddTest( TARGET md_nve + LIBS ${math_libs} SOURCES NVE_test.cpp ../verlet.cpp ../NVE.cpp @@ -70,6 +74,7 @@ AddTest( AddTest( TARGET md_nvt_ads + LIBS ${math_libs} SOURCES NVT_ADS_test.cpp ../verlet.cpp ../NVT_ADS.cpp @@ -78,6 +83,7 @@ AddTest( AddTest( TARGET md_nvt_nhc + LIBS ${math_libs} SOURCES NVT_NHC_test.cpp ../verlet.cpp ../NVT_NHC.cpp @@ -86,6 +92,7 @@ AddTest( AddTest( TARGET md_msst + LIBS ${math_libs} SOURCES MSST_test.cpp ../verlet.cpp ../MSST.cpp @@ -94,6 +101,7 @@ AddTest( AddTest( TARGET md_lgv + LIBS ${math_libs} SOURCES Langevin_test.cpp ../verlet.cpp ../Langevin.cpp diff --git a/source/module_md/test/FIRE_test.cpp b/source/module_md/test/FIRE_test.cpp index 07e875a0bf9..5f07e3e74a4 100644 --- a/source/module_md/test/FIRE_test.cpp +++ b/source/module_md/test/FIRE_test.cpp @@ -14,7 +14,8 @@ class FIRE_test : public testing::Test Setcell::parameters(); verlet = new FIRE(INPUT.mdp, ucell); - verlet->setup(); + ModuleEnSover::En_Solver *p_ensolver; + verlet->setup(p_ensolver); } void TearDown() diff --git a/source/module_md/test/Langevin_test.cpp b/source/module_md/test/Langevin_test.cpp index 0c24f0ad5ac..c124f7ced25 100644 --- a/source/module_md/test/Langevin_test.cpp +++ b/source/module_md/test/Langevin_test.cpp @@ -14,7 +14,8 @@ class Langevin_test : public testing::Test Setcell::parameters(); verlet = new Langevin(INPUT.mdp, ucell); - verlet->setup(); + ModuleEnSover::En_Solver *p_ensolver; + verlet->setup(p_ensolver); } void TearDown() diff --git a/source/module_md/test/MSST_test.cpp b/source/module_md/test/MSST_test.cpp index 158492b5b6e..50f9bf5c430 100644 --- a/source/module_md/test/MSST_test.cpp +++ b/source/module_md/test/MSST_test.cpp @@ -14,7 +14,8 @@ class MSST_test : public testing::Test Setcell::parameters(); verlet = new MSST(INPUT.mdp, ucell); - verlet->setup(); + ModuleEnSover::En_Solver *p_ensolver; + verlet->setup(p_ensolver); } void TearDown() diff --git a/source/module_md/test/NVE_test.cpp b/source/module_md/test/NVE_test.cpp index f485a98f78f..4276be64e46 100644 --- a/source/module_md/test/NVE_test.cpp +++ b/source/module_md/test/NVE_test.cpp @@ -1,6 +1,7 @@ #include "gtest/gtest.h" #include "setcell.h" #include "module_md/NVE.h" +#include "../../module_ensolver/en_solver.h" class NVE_test : public testing::Test { @@ -14,8 +15,9 @@ class NVE_test : public testing::Test Setcell::setupcell(ucell); Setcell::parameters(); verlet = new NVE(INPUT.mdp, ucell); + ModuleEnSover::En_Solver *p_ensolver; - verlet->setup(); + verlet->setup(p_ensolver); } void TearDown() diff --git a/source/module_md/test/NVT_ADS_test.cpp b/source/module_md/test/NVT_ADS_test.cpp index 461bdc6ca49..c65b675152a 100644 --- a/source/module_md/test/NVT_ADS_test.cpp +++ b/source/module_md/test/NVT_ADS_test.cpp @@ -14,7 +14,8 @@ class NVT_ADS_test : public testing::Test Setcell::parameters(); verlet = new NVT_ADS(INPUT.mdp, ucell); - verlet->setup(); + ModuleEnSover::En_Solver *p_ensolver; + verlet->setup(p_ensolver); } void TearDown() diff --git a/source/module_md/test/NVT_NHC_test.cpp b/source/module_md/test/NVT_NHC_test.cpp index cdaa760480e..ea502208a81 100644 --- a/source/module_md/test/NVT_NHC_test.cpp +++ b/source/module_md/test/NVT_NHC_test.cpp @@ -14,7 +14,8 @@ class NVT_NHC_test : public testing::Test Setcell::parameters(); verlet = new NVT_NHC(INPUT.mdp, ucell); - verlet->setup(); + ModuleEnSover::En_Solver *p_ensolver; + verlet->setup(p_ensolver); } void TearDown() diff --git a/source/module_md/verlet.cpp b/source/module_md/verlet.cpp index 239a6a13c34..5d8a2c3be09 100644 --- a/source/module_md/verlet.cpp +++ b/source/module_md/verlet.cpp @@ -4,6 +4,7 @@ #include "mpi.h" #endif #include "../module_base/timer.h" +#include "module_ensolver/en_solver.h" Verlet::Verlet(MD_parameters& MD_para_in, UnitCell_pseudo &unit_in): mdp(MD_para_in), @@ -46,14 +47,14 @@ Verlet::~Verlet() delete []force; } -void Verlet::setup() +void Verlet::setup(ModuleEnSover::En_Solver *p_ensolver) { if(mdp.rstMD) { restart(); } - MD_func::force_virial(step_, mdp, ucell, potential, force, virial); + MD_func::force_virial(p_ensolver, step_, mdp, ucell, potential, force, virial); MD_func::kinetic_stress(ucell, vel, allmass, kinetic, stress); stress += virial; diff --git a/source/module_md/verlet.h b/source/module_md/verlet.h index bdca2da825f..b96e1d20565 100644 --- a/source/module_md/verlet.h +++ b/source/module_md/verlet.h @@ -4,6 +4,7 @@ #include "MD_parameters.h" #include "../module_cell/unitcell_pseudo.h" #include "../module_base/matrix.h" +#include "module_ensolver/en_solver.h" class Verlet { @@ -11,7 +12,7 @@ class Verlet Verlet(MD_parameters& MD_para_in, UnitCell_pseudo &unit_in); virtual ~Verlet(); - virtual void setup(); + virtual void setup(ModuleEnSover::En_Solver *p_ensolver); virtual void first_half(); virtual void second_half(); virtual void outputMD(); diff --git a/source/module_orbital/CMakeLists.txt b/source/module_orbital/CMakeLists.txt index 3a753030426..f5e4157844e 100644 --- a/source/module_orbital/CMakeLists.txt +++ b/source/module_orbital/CMakeLists.txt @@ -12,4 +12,11 @@ add_library( ORB_table_alpha.cpp ORB_table_beta.cpp ORB_table_phi.cpp + parallel_orbitals.cpp ) + +set(CMAKE_CXX_STANDARD_REQUIRED ON) +IF (BUILD_TESTING) + set(CMAKE_CXX_STANDARD 14) + add_subdirectory(test) +endif() \ No newline at end of file diff --git a/source/module_orbital/ORB_control.cpp b/source/module_orbital/ORB_control.cpp index 3ea5a898977..c77e3f3f4bd 100644 --- a/source/module_orbital/ORB_control.cpp +++ b/source/module_orbital/ORB_control.cpp @@ -1,11 +1,44 @@ #include "ORB_control.h" #include "ORB_gen_tables.h" #include "../module_base/timer.h" +#include "../src_parallel/parallel_common.h" +#include "../src_io/wf_local.h" +#include "../module_base/lapack_connector.h" +#include "../module_base/memory.h" + //#include "build_st_pw.h" -ORB_control::ORB_control() -{} +ORB_control::ORB_control( + const bool& gamma_only_in, + const int& nlocal_in, + const int& nbands_in, + const int& nspin_in, + const int& dsize_in, + const int& nb2d_in, + const int& dcolor_in, + const int& drank_in, + const int& myrank_in, + const std::string& calculation_in, + const std::string& ks_solver_in) : + gamma_only(gamma_only_in), + nlocal(nlocal_in), + nbands(nbands_in), + nspin(nspin_in), + dsize(dsize_in), + nb2d(nb2d_in), + dcolor(dcolor_in), + drank(drank_in), + myrank(myrank_in), + calculation(calculation_in), + ks_solver(ks_solver_in), + setup_2d(true) +{ + this->ParaV.nspin = nspin_in; +} +ORB_control::ORB_control() : + setup_2d(false) +{} ORB_control::~ORB_control() {} @@ -78,7 +111,7 @@ void ORB_control::set_orb_tables( #ifdef __NORMAL #else - if(GlobalV::CALCULATION=="test") + if(calculation=="test") { ModuleBase::timer::tick("ORB_control","set_orb_tables"); return; @@ -126,20 +159,21 @@ void ORB_control::clear_after_ions( } -void ORB_control::setup_2d_division(void) +void ORB_control::setup_2d_division(std::ofstream& ofs_running, + std::ofstream& ofs_warning) { ModuleBase::TITLE("ORB_control","setup_2d_division"); - GlobalV::ofs_running << "\n SETUP THE DIVISION OF H/S MATRIX" << std::endl; + ofs_running << "\n SETUP THE DIVISION OF H/S MATRIX" << std::endl; // (1) calculate nrow, ncol, nloc. - if (GlobalV::KS_SOLVER=="genelpa" || GlobalV::KS_SOLVER=="hpseps" || GlobalV::KS_SOLVER=="scalpack" - || GlobalV::KS_SOLVER=="selinv" || GlobalV::KS_SOLVER=="scalapack_gvx") + if (ks_solver=="genelpa" || ks_solver=="hpseps" || ks_solver=="scalpack" + || ks_solver=="selinv" || ks_solver=="scalapack_gvx") { - GlobalV::ofs_running << " divide the H&S matrix using 2D block algorithms." << std::endl; + ofs_running << " divide the H&S matrix using 2D block algorithms." << std::endl; #ifdef __MPI // storage form of H and S matrices on each processor // is determined in 'divide_HS_2d' subroutine - this->divide_HS_2d(DIAG_WORLD); + this->divide_HS_2d(DIAG_WORLD, ofs_running, ofs_warning); #else ModuleBase::WARNING_QUIT("LCAO_Matrix::init","diago method is not ready."); #endif @@ -147,16 +181,536 @@ void ORB_control::setup_2d_division(void) else { // the full matrix - this->ParaV.nloc = GlobalV::NLOCAL * GlobalV::NLOCAL; + this->ParaV.nloc = nlocal * nlocal; } // (2) set the trace, then we can calculate the nnr. // for 2d: calculate po.nloc first, then trace_loc_row and trace_loc_col // for O(N): calculate the three together. - this->set_trace(); + this->set_trace(ofs_running); +} + + +void ORB_control::set_parameters(std::ofstream& ofs_running, + std::ofstream& ofs_warning) +{ + ModuleBase::TITLE("ORB_control","set_parameters"); + + Parallel_Orbitals* pv = &this->ParaV; + // set loc_size + if(gamma_only)//xiaohui add 2014-12-21 + { + pv->loc_size=nbands/dsize; + + // mohan add 2012-03-29 + if(pv->loc_size==0) + { + ofs_warning << " loc_size=0" << " in proc " << myrank+1 << std::endl; + ModuleBase::WARNING_QUIT("ORB_control::set_parameters","NLOCAL < DSIZE"); + } + + if (drankloc_size+=1; + if(pv->testpb)ModuleBase::GlobalFunc::OUT(ofs_running,"local size",pv->loc_size); + + // set loc_sizes + delete[] pv->loc_sizes; + pv->loc_sizes = new int[dsize]; + ModuleBase::GlobalFunc::ZEROS(pv->loc_sizes, dsize); + + pv->lastband_in_proc = 0; + pv->lastband_number = 0; + int count_bands = 0; + for (int i=0; iloc_sizes[i]=nbands/dsize+1; + } + else + { + pv->loc_sizes[i]=nbands/dsize; + } + count_bands += pv->loc_sizes[i]; + if (count_bands >= nbands) + { + pv->lastband_in_proc = i; + pv->lastband_number = nbands - (count_bands - pv->loc_sizes[i]); + break; + } + } + } + else + { + pv->loc_size=nlocal/dsize; + + // mohan add 2012-03-29 + if(pv->loc_size==0) + { + ofs_warning << " loc_size=0" << " in proc " << myrank+1 << std::endl; + ModuleBase::WARNING_QUIT("ORB_control::set_parameters","NLOCAL < DSIZE"); + } + + if (drankloc_size += 1; + } + if(pv->testpb) ModuleBase::GlobalFunc::OUT(ofs_running,"local size",pv->loc_size); + + // set loc_sizes + delete[] pv->loc_sizes; + pv->loc_sizes = new int[dsize]; + ModuleBase::GlobalFunc::ZEROS(pv->loc_sizes, dsize); + + pv->lastband_in_proc = 0; + pv->lastband_number = 0; + int count_bands = 0; + for (int i=0; iloc_sizes[i]=nlocal/dsize+1; + } + else + { + pv->loc_sizes[i]=nlocal/dsize; + } + count_bands += pv->loc_sizes[i]; + if (count_bands >= nbands) + { + pv->lastband_in_proc = i; + pv->lastband_number = nbands - (count_bands - pv->loc_sizes[i]); + break; + } + } + }//xiaohui add 2014-12-21 + + if (ks_solver=="hpseps") //LiuXh add 2021-09-06, clear memory, Z_LOC only used in hpseps solver + { + pv->Z_LOC = new double*[nspin]; + for(int is=0; isZ_LOC[is] = new double[pv->loc_size * nlocal]; + ModuleBase::GlobalFunc::ZEROS(pv->Z_LOC[is], pv->loc_size * nlocal); + } + pv->alloc_Z_LOC = true;//xiaohui add 2014-12-22 + } + + if(pv->testpb)ModuleBase::GlobalFunc::OUT(ofs_running,"lastband_in_proc", pv->lastband_in_proc); + if(pv->testpb)ModuleBase::GlobalFunc::OUT(ofs_running,"lastband_number", pv->lastband_number); + + return; +} + + +#ifdef __MPI +// creat the 'comm_2D' stratege. +void ORB_control::mpi_creat_cart(MPI_Comm* comm_2D, + int prow, int pcol, std::ofstream& ofs_running) +{ + ModuleBase::TITLE("ORB_control","mpi_creat_cart"); + // the matrix is divided as ( dim[0] * dim[1] ) + int dim[2]; + int period[2]={1,1}; + int reorder=0; + dim[0]=prow; + dim[1]=pcol; + + if(this->ParaV.testpb) ofs_running << " dim = " << dim[0] << " * " << dim[1] << std::endl; + + MPI_Cart_create(DIAG_WORLD,2,dim,period,reorder,comm_2D); + return; +} +#endif + +#ifdef __MPI +void ORB_control::mat_2d(MPI_Comm vu, + const int &M_A, + const int &N_A, + const int &nb, + LocalMatrix &LM, + std::ofstream& ofs_running, + std::ofstream& ofs_warning) +{ + ModuleBase::TITLE("ORB_control", "mat_2d"); + + Parallel_Orbitals* pv = &this->ParaV; + + int dim[2]; + int period[2]; + int coord[2]; + int i,j,k,end_id; + int block; + + // (0) every processor get it's id on the 2D comm + // : ( coord[0], coord[1] ) + MPI_Cart_get(vu,2,dim,period,coord); + + // (1.1) how many blocks at least + // eg. M_A = 6400, nb = 64; + // so block = 10; + block=M_A/nb; + + // (1.2) If data remain, add 1. + if (block*nbtestpb)ModuleBase::GlobalFunc::OUT(ofs_running,"Total Row Blocks Number",block); + + // mohan add 2010-09-12 + if(dim[0]>block) + { + ofs_warning << " cpu 2D distribution : " << dim[0] << "*" << dim[1] << std::endl; + ofs_warning << " but, the number of row blocks is " << block << std::endl; + ModuleBase::WARNING_QUIT("ORB_control::mat_2d","some processor has no row blocks, try a smaller 'nb2d' parameter."); + } + + // (2.1) row_b : how many blocks for this processor. (at least) + LM.row_b=block/dim[0]; + + // (2.2) row_b : how many blocks in this processor. + // if there are blocks remain, some processors add 1. + if (coord[0]testpb)ModuleBase::GlobalFunc::OUT(ofs_running,"Local Row Block Number",LM.row_b); + + // (3) end_id indicates the last block belong to + // which processor. + if (block%dim[0]==0) + { + end_id=dim[0]-1; + } + else + { + end_id=block%dim[0]-1; + } + + if(pv->testpb)ModuleBase::GlobalFunc::OUT(ofs_running,"Ending Row Block in processor",end_id); + + // (4) row_num : how many rows in this processors : + // the one owns the last block is different. + if (coord[0]==end_id) + { + LM.row_num=(LM.row_b-1)*nb+(M_A-(block-1)*nb); + } + else + { + LM.row_num=LM.row_b*nb; + } + + if(pv->testpb)ModuleBase::GlobalFunc::OUT(ofs_running,"Local rows (including nb)",LM.row_num); + + // (5) row_set, it's a global index : + // save explicitly : every row in this processor + // belongs to which row in the global matrix. + delete[] LM.row_set; + LM.row_set= new int[LM.row_num]; + j=0; + for (i=0; iblock) + { + ofs_warning << " cpu 2D distribution : " << dim[0] << "*" << dim[1] << std::endl; + ofs_warning << " but, the number of column blocks is " << block << std::endl; + ModuleBase::WARNING_QUIT("ORB_control::mat_2d","some processor has no column blocks."); + } + + LM.col_b=block/dim[1]; + if (coord[1]testpb)ModuleBase::GlobalFunc::OUT(ofs_running,"Local Row Block Number",LM.col_b); + + if (block%dim[1]==0) + { + end_id=dim[1]-1; + } + else + { + end_id=block%dim[1]-1; + } + + if(pv->testpb)ModuleBase::GlobalFunc::OUT(ofs_running,"Ending Row Block in processor",end_id); + + if (coord[1]==end_id) + { + LM.col_num=(LM.col_b-1)*nb+(M_A-(block-1)*nb); + } + else + { + LM.col_num=LM.col_b*nb; + } + + if(pv->testpb)ModuleBase::GlobalFunc::OUT(ofs_running,"Local columns (including nb)",LM.row_num); + + delete[] LM.col_set; + LM.col_set = new int[LM.col_num]; + + j=0; + for (i=0; iblock) + { + ofs_warning << " cpu 2D distribution : " << dim[0] << "*" << dim[1] << std::endl; + ofs_warning << " but, the number of bands-row-block is " << block << std::endl; + ModuleBase::WARNING_QUIT("ORB_control::mat_2d","some processor has no bands-row-blocks."); + } + int col_b_bands = block / dim[1]; + if (coord[1] < block % dim[1]) + { + col_b_bands++; + } + if (block%dim[1]==0) + { + end_id=dim[1]-1; + } + else + { + end_id=block%dim[1]-1; + } + if (coord[1]==end_id) + { + pv->ncol_bands=(col_b_bands-1)*nb+(N_A-(block-1)*nb); + } + else + { + pv->ncol_bands=col_b_bands*nb; + } + pv->nloc_wfc = pv->ncol_bands * LM.row_num; + + return; } +#endif + + +#ifdef __MPI +// A : contains total matrix element in processor. +void ORB_control::data_distribution( + MPI_Comm comm_2D, + const std::string &file, + const int &n, + const int &nb, + double *A, + const LocalMatrix &LM) +{ + ModuleBase::TITLE("ORB_control", "data_distribution"); + Parallel_Orbitals* pv = &this->ParaV; + MPI_Comm comm_row; + MPI_Comm comm_col; + MPI_Status status; + + int dim[2]; + int period[2]; + int coord[2]; + MPI_Cart_get(comm_2D,2,dim,period,coord); + + if(pv->testpb) ofs_running << "\n dim = " << dim[0] << " * " << dim[1] << std::endl; + if(pv->testpb) ofs_running << " coord = ( " << coord[0] << " , " << coord[1] << ")." << std::endl; + if(pv->testpb) ofs_running << " n = " << n << std::endl; + + mpi_sub_col(comm_2D,&comm_col); + mpi_sub_row(comm_2D,&comm_row); + + // total number of processors + const int myid = coord[0]*dim[1]+coord[1]; + + // the matrix is n * n + double* ele_val = new double[n]; + double* val = new double[n]; + int* sends = new int[dim[1]]; + int* fpt = new int[dim[1]]; + int* snd = new int[dim[1]]; + int* temp = new int[dim[1]]; + + ModuleBase::GlobalFunc::ZEROS(ele_val, n); + ModuleBase::GlobalFunc::ZEROS(val, n); + ModuleBase::GlobalFunc::ZEROS(sends, dim[1]); + ModuleBase::GlobalFunc::ZEROS(fpt, dim[1]); + ModuleBase::GlobalFunc::ZEROS(snd, dim[1]); + ModuleBase::GlobalFunc::ZEROS(temp, dim[1]); + + // the columes of matrix is divided by 'dim[1]' 'rows of processors'. + // collect all information of each 'rows of processors' + // collection data is saved in 'sends' + snd[coord[1]] = LM.col_num; + MPI_Allgather(&snd[coord[1]],1,MPI_INT,sends,1,MPI_INT,comm_row); + + // fpt : start column index after applied 'mat_2d' reorder algorithms + // to matrix. + fpt[0] = 0; + for (int i=1; impi_creat_cart(&pv->comm_2D,dim[0],dim[1]); @@ -193,12 +747,12 @@ void ORB_control::readin( // call mat_2d this->mat_2d(pv->comm_2D, nlocal_tot,nlocal_tot,pv->nb,pv->MatrixInfo); - pv->loc_size=nlocal_tot/GlobalV::DSIZE; - if (GlobalV::DRANKloc_size=nlocal_tot/dsize; + if (drankdata_distribution(pv->comm_2D,fa,nlocal_tot,pv->nb,A,pv->MatrixInfo); - GlobalV::ofs_running << "\n Data distribution of S." << std::endl; + ofs_running << "\n Data distribution of S." << std::endl; this->data_distribution(pv->comm_2D,fb,nlocal_tot,pv->nb,B,pv->MatrixInfo); time1=MPI_Wtime(); @@ -218,19 +772,20 @@ void ORB_control::readin( char uplo = 'U'; pdgseps(pv->comm_2D,nlocal_tot,pv->nb,A,B,Z,eigen,pv->MatrixInfo,uplo,pv->loc_size,loc_pos); time2=MPI_Wtime(); - ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"time1",time1); - ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"time2",time2); + ModuleBase::GlobalFunc::OUT(ofs_running,"time1",time1); + ModuleBase::GlobalFunc::OUT(ofs_running,"time2",time2); //this->gath_eig(comm,n,eigvr,Z); - GlobalV::ofs_running << "\n " << std::setw(6) << "Band" << std::setw(25) << "Ry" << std::setw(25) << " eV" << std::endl; + ofs_running << "\n " << std::setw(6) << "Band" << std::setw(25) << "Ry" << std::setw(25) << " eV" << std::endl; for(int i=0; inspin; is++) { delete[] Z_LOC[is]; } @@ -59,70 +58,70 @@ bool Parallel_Orbitals::in_this_processor(const int &iw1_all, const int &iw2_all return true; } -void ORB_control::set_trace(void) +void ORB_control::set_trace(std::ofstream& ofs_running) { ModuleBase::TITLE("ORB_control","set_trace"); - assert(GlobalV::NLOCAL > 0); + assert(nlocal > 0); Parallel_Orbitals* pv = &this->ParaV; delete[] pv->trace_loc_row; delete[] pv->trace_loc_col; - ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"trace_loc_row dimension",GlobalV::NLOCAL); - ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"trace_loc_col dimension",GlobalV::NLOCAL); + ModuleBase::GlobalFunc::OUT(ofs_running,"trace_loc_row dimension",nlocal); + ModuleBase::GlobalFunc::OUT(ofs_running,"trace_loc_col dimension",nlocal); - pv->trace_loc_row = new int[GlobalV::NLOCAL]; - pv->trace_loc_col = new int[GlobalV::NLOCAL]; + pv->trace_loc_row = new int[nlocal]; + pv->trace_loc_col = new int[nlocal]; // mohan update 2011-04-07 - for(int i=0; itrace_loc_row[i] = -1; pv->trace_loc_col[i] = -1; } - ModuleBase::Memory::record("ORB_control","trace_loc_row",GlobalV::NLOCAL,"int"); - ModuleBase::Memory::record("ORB_control","trace_loc_col",GlobalV::NLOCAL,"int"); + ModuleBase::Memory::record("ORB_control","trace_loc_row",nlocal,"int"); + ModuleBase::Memory::record("ORB_control","trace_loc_col",nlocal,"int"); - if(GlobalV::KS_SOLVER=="lapack" - || GlobalV::KS_SOLVER=="cg" - || GlobalV::KS_SOLVER=="dav") //xiaohui add 2013-09-02 + if(ks_solver=="lapack" + || ks_solver=="cg" + || ks_solver=="dav") //xiaohui add 2013-09-02 { std::cout << " common settings for trace_loc_row and trace_loc_col " << std::endl; - for (int i=0; itrace_loc_row[i] = i; pv->trace_loc_col[i] = i; } - pv->nrow = GlobalV::NLOCAL; - pv->ncol = GlobalV::NLOCAL; + pv->nrow = nlocal; + pv->ncol = nlocal; } #ifdef __MPI - else if(GlobalV::KS_SOLVER=="scalpack" || GlobalV::KS_SOLVER=="genelpa" || GlobalV::KS_SOLVER=="hpseps" - || GlobalV::KS_SOLVER=="selinv" || GlobalV::KS_SOLVER=="scalapack_gvx") //xiaohui add 2013-09-02 + else if(ks_solver=="scalpack" || ks_solver=="genelpa" || ks_solver=="hpseps" + || ks_solver=="selinv" || ks_solver=="scalapack_gvx") //xiaohui add 2013-09-02 { - // GlobalV::ofs_running << " nrow=" << nrow << std::endl; + // ofs_running << " nrow=" << nrow << std::endl; for (int irow=0; irow< pv->nrow; irow++) { int global_row = pv->MatrixInfo.row_set[irow]; pv->trace_loc_row[global_row] = irow; - // GlobalV::ofs_running << " global_row=" << global_row + // ofs_running << " global_row=" << global_row // << " trace_loc_row=" << pv->trace_loc_row[global_row] << std::endl; } - // GlobalV::ofs_running << " ncol=" << ncol << std::endl; + // ofs_running << " ncol=" << ncol << std::endl; for (int icol=0; icol< pv->ncol; icol++) { int global_col = pv->MatrixInfo.col_set[icol]; pv->trace_loc_col[global_col] = icol; - // GlobalV::ofs_running << " global_col=" << global_col + // ofs_running << " global_col=" << global_col // << " trace_loc_col=" << pv->trace_loc_row[global_col] << std::endl; } } #endif else { - std::cout << " Parallel Orbial, GlobalV::DIAGO_TYPE = " << GlobalV::KS_SOLVER << std::endl; + std::cout << " Parallel Orbial, DIAGO_TYPE = " << ks_solver << std::endl; ModuleBase::WARNING_QUIT("ORB_control::set_trace","Check ks_solver."); } @@ -130,17 +129,17 @@ void ORB_control::set_trace(void) // print the trace for test. //--------------------------- /* - GlobalV::ofs_running << " " << std::setw(10) << "GlobalRow" << std::setw(10) << "LocalRow" << std::endl; - for(int i=0; itrace_loc_row[i] << std::endl; + ofs_running << " " << std::setw(10) << i << std::setw(10) << pv->trace_loc_row[i] << std::endl; } - GlobalV::ofs_running << " " << std::setw(10) << "GlobalCol" << std::setw(10) << "LocalCol" << std::endl; - for(int j=0; j0); - assert(GlobalV::DSIZE > 0); + assert(nlocal>0); + assert(dsize > 0); Parallel_Orbitals* pv = &this->ParaV; #ifdef __MPI DIAG_HPSEPS_WORLD=DIAG_WORLD; #endif - if(GlobalV::DCOLOR!=0) return; // mohan add 2012-01-13 + if(dcolor!=0) return; // mohan add 2012-01-13 // get the 2D index of computer. - pv->dim0 = (int)sqrt((double)GlobalV::DSIZE); //mohan update 2012/01/13 + pv->dim0 = (int)sqrt((double)dsize); //mohan update 2012/01/13 //while (GlobalV::NPROC_IN_POOL%dim0!=0) - while (GlobalV::DSIZE%pv->dim0!=0) + while (dsize%pv->dim0!=0) { pv->dim0 = pv->dim0 - 1; } assert(pv->dim0 > 0); - pv->dim1=GlobalV::DSIZE/pv->dim0; + pv->dim1=dsize/pv->dim0; - if(pv->testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"dim0",pv->dim0); - if(pv->testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"dim1",pv->dim1); + if(pv->testpb)ModuleBase::GlobalFunc::OUT(ofs_running,"dim0",pv->dim0); + if(pv->testpb)ModuleBase::GlobalFunc::OUT(ofs_running,"dim1",pv->dim1); #ifdef __MPI // mohan add 2011-04-16 - if(GlobalV::NB2D==0) + if(nb2d==0) { - if(GlobalV::NLOCAL>0) pv->nb = 1; - if(GlobalV::NLOCAL>500) pv->nb = 32; - if(GlobalV::NLOCAL>1000) pv->nb = 64; + if(nlocal>0) pv->nb = 1; + if(nlocal>500) pv->nb = 32; + if(nlocal>1000) pv->nb = 64; } - else if(GlobalV::NB2D>0) + else if(nb2d>0) { - pv->nb = GlobalV::NB2D; // mohan add 2010-06-28 + pv->nb = nb2d; // mohan add 2010-06-28 } - ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"nb2d", pv->nb); + ModuleBase::GlobalFunc::OUT(ofs_running,"nb2d", pv->nb); - this->set_parameters(); + this->set_parameters(ofs_running, ofs_warning); // call mpi_creat_cart - this->mpi_creat_cart(&pv->comm_2D,pv->dim0,pv->dim1); + this->mpi_creat_cart(&pv->comm_2D,pv->dim0,pv->dim1, ofs_running); // call mat_2d - this->mat_2d(pv->comm_2D, GlobalV::NLOCAL, GlobalV::NBANDS, pv->nb, pv->MatrixInfo); + this->mat_2d(pv->comm_2D, nlocal, nbands, pv->nb, + pv->MatrixInfo, ofs_running, ofs_warning); // mohan add 2010-06-29 pv->nrow = pv->MatrixInfo.row_num; @@ -243,31 +244,32 @@ void ORB_control::divide_HS_2d pv->nloc = pv->MatrixInfo.col_num * pv->MatrixInfo.row_num; // init blacs context for genelpa - if(GlobalV::KS_SOLVER=="genelpa" || GlobalV::KS_SOLVER=="scalapack_gvx") + if(ks_solver=="genelpa" || ks_solver=="scalapack_gvx") { - pv->blacs_ctxt=cart2blacs(pv->comm_2D, pv->dim0, pv->dim1, GlobalV::NLOCAL, GlobalV::NBANDS, pv->nb, pv->nrow, pv->desc, pv->desc_wfc); + pv->blacs_ctxt = cart2blacs(pv->comm_2D, pv->dim0, pv->dim1, + nlocal, nbands, pv->nb, pv->nrow, pv->desc, pv->desc_wfc); } #else // single processor used. - pv->nb = GlobalV::NLOCAL; - pv->nrow = GlobalV::NLOCAL; - pv->ncol = GlobalV::NLOCAL; - pv->nloc = GlobalV::NLOCAL * GlobalV::NLOCAL; - this->set_parameters(); - pv->MatrixInfo.row_b = 1; - pv->MatrixInfo.row_num = GlobalV::NLOCAL; + pv->nb = nlocal; + pv->nrow = nlocal; + pv->ncol = nlocal; + pv->nloc = nlocal * nlocal; + this->set_parameters(ofs_running, ofs_warning); + pv->MatrixInfo.row_b = 1; + pv->MatrixInfo.row_num = nlocal; delete[] pv->MatrixInfo.row_set; - pv->MatrixInfo.row_set = new int[GlobalV::NLOCAL]; - for(int i=0; iMatrixInfo.row_set = new int[nlocal]; + for(int i=0; iMatrixInfo.row_set[i]=i; } pv->MatrixInfo.row_pos=0; pv->MatrixInfo.col_b = 1; - pv->MatrixInfo.col_num = GlobalV::NLOCAL; + pv->MatrixInfo.col_num = nlocal; delete[] pv->MatrixInfo.col_set; - pv->MatrixInfo.col_set = new int[GlobalV::NLOCAL]; - for(int i=0; iMatrixInfo.col_set = new int[nlocal]; + for(int i=0; iMatrixInfo.col_set[i]=i; } @@ -275,8 +277,8 @@ void ORB_control::divide_HS_2d #endif assert(pv->nloc>0); - if(pv->testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"MatrixInfo.row_num",pv->MatrixInfo.row_num); - if(pv->testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"MatrixInfo.col_num",pv->MatrixInfo.col_num); - if(pv->testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"nloc",pv->nloc); + if(pv->testpb)ModuleBase::GlobalFunc::OUT(ofs_running,"MatrixInfo.row_num",pv->MatrixInfo.row_num); + if(pv->testpb)ModuleBase::GlobalFunc::OUT(ofs_running,"MatrixInfo.col_num",pv->MatrixInfo.col_num); + if(pv->testpb)ModuleBase::GlobalFunc::OUT(ofs_running,"nloc",pv->nloc); return; } diff --git a/source/src_parallel/parallel_orbitals.h b/source/module_orbital/parallel_orbitals.h similarity index 88% rename from source/src_parallel/parallel_orbitals.h rename to source/module_orbital/parallel_orbitals.h index de63476a053..4d4862bb06b 100644 --- a/source/src_parallel/parallel_orbitals.h +++ b/source/module_orbital/parallel_orbitals.h @@ -39,11 +39,11 @@ struct Parallel_Orbitals int lastband_in_proc; int lastband_number; - //--------------------------------------- - // number of elements in H or S matrix, - // nnr -> 2D block distribution; - //--------------------------------------- - int nnr; + ///--------------------------------------- + /// number of elements(basis-pairs) in this processon + /// on all adjacent atoms-pairs(2D division) + ///--------------------------------------- + int nnr; int *nlocdim; int *nlocstart; @@ -57,7 +57,8 @@ struct Parallel_Orbitals #endif /// only used in hpseps-diago - int* loc_sizes; + int nspin = 1; + int* loc_sizes; int loc_size; bool alloc_Z_LOC; //xiaohui add 2014-12-22 double** Z_LOC; //xiaohui add 2014-06-19 @@ -67,15 +68,11 @@ struct Parallel_Orbitals // test parameter int testpb; - + /// check whether a basis element is in this processor /// (check whether local-index > 0 ) bool in_this_processor(const int& iw1_all, const int& iw2_all) const; - /// number of elements(basis-pairs) in this processon - /// on all adjacent atoms-pairs(2D division) - void cal_nnr(); - }; diff --git a/source/module_orbital/test/1_snap_equal_test.cpp b/source/module_orbital/test/1_snap_equal_test.cpp new file mode 100644 index 00000000000..a3fe5423cb7 --- /dev/null +++ b/source/module_orbital/test/1_snap_equal_test.cpp @@ -0,0 +1,72 @@ +#include +#include"ORB_unittest.h" + +//Test whether the 2-center-int results +// and its derivative from two clases are equal. +// - ORB_gen_table::snap_psipsi(job=0) and Center2_Orb::Orb11::cal_overlap +// - ORB_gen_table::snap_psipsi(job=1) and Center2_Orb::Orb11::cal_grad_overlap +TEST_F(test_orb, equal_test) +{ + + this->set_center2orbs(); + //equal test + //orb + double olm_0[1] = { 0 }; + double olm_1[3] = { 0,0,0 }; + //center2orb + double clm_0 = 0; + ModuleBase::Vector3 clm_1; + + //test parameters + const double rmax = 5; //Ry + srand((unsigned)time(NULL)); + ModuleBase::Vector3 R1(0, 0, 0); + ModuleBase::Vector3 R2(randr(rmax), randr(rmax), randr(rmax)); + std::cout << "random R2=(" << R2.x << "," << R2.y << "," << R2.z << ")" << std::endl; + ModuleBase::Vector3 dR = ModuleBase::Vector3(0.001, 0.001, 0.001); + //4. calculate overlap and grad_overlap by both methods + int T1 = 0; + + for (int T2 = 0;T2 < ORB.get_ntype();++T2) + { + for (int L1 = 0;L1 < ORB.Phi[T1].getLmax();++L1) + { + for (int N1 = 0;N1 < ORB.Phi[T1].getNchi(L1);++N1) + { + for (int L2 = 0;L2 < ORB.Phi[T2].getLmax();++L2) + { + for (int N2 = 0;N2 < ORB.Phi[T2].getNchi(L2);++N2) + { + for (int m1 = 0;m1 < 2 * L1 + 1;++m1) + { + for (int m2 = 0;m2 < 2 * L2 + 1;++m2) + { + OGT.snap_psipsi( + ORB, olm_0, 0, 'S', + R1, T1, L1, m1, N1, + R2, T2, L2, m2, N2, + 1, NULL); + OGT.snap_psipsi( + ORB, olm_1, 1, 'S', + R1, T1, L1, m1, N1, + R2, T2, L2, m2, N2, + 1, NULL); + //std::cout << this->mock_center2_orb11[T1][T2][L1][N1][L2][N2]->cal_overlap(R1, R2, m1, m2); + clm_0 = + test_center2_orb11[T1][T2][L1][N1][L2][N2]->cal_overlap(R1, R2, m1, m2); + clm_1 = + test_center2_orb11[T1][T2][L1][N1][L2][N2]->cal_grad_overlap(R1, R2, m1, m2); + EXPECT_NEAR(olm_0[0], clm_0, 1e-10); + EXPECT_NEAR(olm_1[0], clm_1.x, 1e-10); + EXPECT_NEAR(olm_1[1], clm_1.y, 1e-10); + EXPECT_NEAR(olm_1[2], clm_1.z, 1e-10); + ModuleBase::GlobalFunc::ZEROS(olm_1, 3); + } + } + } + + } + } + } + } +} \ No newline at end of file diff --git a/source/module_orbital/test/CMakeLists.txt b/source/module_orbital/test/CMakeLists.txt new file mode 100644 index 00000000000..c4f71f09fa1 --- /dev/null +++ b/source/module_orbital/test/CMakeLists.txt @@ -0,0 +1,47 @@ +remove_definitions(-D__MPI) + +list(APPEND depend_files + ../../module_base/math_integral.cpp + ../../module_base/math_sphbes.cpp + ../../module_base/math_polyint.cpp + ../../module_base/math_ylmreal.cpp + ../../module_base/ylm.cpp + ../../module_base/memory.cpp + ../../module_base/complexarray.cpp + ../../module_base/complexmatrix.cpp + ../../module_base/matrix.cpp + ../../module_base/realarray.cpp + ../../module_base/intarray.cpp + ../../module_base/sph_bessel.cpp + ../../module_base/sph_bessel_recursive-d1.cpp + ../../module_base/sph_bessel_recursive-d2.cpp + ../../module_base/tool_title.cpp + ../../module_base/tool_quit.cpp + ../../module_base/tool_check.cpp + ../../module_base/timer.cpp + ../../module_base/mathzone_add1.cpp + ../../module_base/global_variable.cpp + ../../module_base/global_function.cpp + ../../module_base/global_file.cpp + ../ORB_control.cpp + ../ORB_read.cpp + ../ORB_atomic.cpp + ../ORB_atomic_lm.cpp + ../ORB_nonlocal.cpp + ../ORB_nonlocal_lm.cpp + ../ORB_gaunt_table.cpp + ../ORB_table_beta.cpp + ../ORB_table_phi.cpp + ../ORB_table_alpha.cpp + ../ORB_gen_tables.cpp + ../parallel_orbitals.cpp + ../../src_lcao/center2_orb-orb11.cpp + ) +AddTest( + TARGET orbital_equal_test + LIBS ${math_libs} + SOURCES 1_snap_equal_test.cpp ORB_unittest.cpp + ${depend_files} +) +install(DIRECTORY GaAs DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/../../../tests) +install(DIRECTORY GaAs DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) \ No newline at end of file diff --git a/tests/module_orb/GaAs/As_dojo.orb b/source/module_orbital/test/GaAs/As_dojo.orb similarity index 100% rename from tests/module_orb/GaAs/As_dojo.orb rename to source/module_orbital/test/GaAs/As_dojo.orb diff --git a/tests/module_orb/GaAs/Ga_dojo.orb b/source/module_orbital/test/GaAs/Ga_dojo.orb similarity index 100% rename from tests/module_orb/GaAs/Ga_dojo.orb rename to source/module_orbital/test/GaAs/Ga_dojo.orb diff --git a/tests/module_orb/GaAs/README b/source/module_orbital/test/GaAs/README similarity index 100% rename from tests/module_orb/GaAs/README rename to source/module_orbital/test/GaAs/README diff --git a/tests/module_orb/GaAs/STRU b/source/module_orbital/test/GaAs/STRU similarity index 100% rename from tests/module_orb/GaAs/STRU rename to source/module_orbital/test/GaAs/STRU diff --git a/tests/module_orb/src/Makefile b/source/module_orbital/test/Makefile similarity index 83% rename from tests/module_orb/src/Makefile rename to source/module_orbital/test/Makefile index 5887dd5b5cb..5fb47a29ffc 100644 --- a/tests/module_orb/src/Makefile +++ b/source/module_orbital/test/Makefile @@ -5,8 +5,10 @@ #========================== CPLUSPLUS = icpc CPLUSPLUS_MPI = icpc -FFTW_DIR = /home/liwenfei/codes/FFTW3 +FFTW_DIR = /home/fortneu49/soft/fftw-3.3.8 OBJ_DIR = orb_obj +GTEST_DIR = /usr/local/lib +GMOCK_DIR = /usr/local/lib NP = 1 #========================== @@ -20,7 +22,7 @@ FFTW_LIB = -L${FFTW_LIB_DIR} -lfftw3 -Wl,-rpath=${FFTW_LIB_DIR} #========================== # LIBS and INCLUDES #========================== -LIBS = -lifcore -lm -lpthread ${FFTW_LIB} +LIBS = -lifcore -lm -lpthread ${FFTW_LIB} ${GTEST_DIR}/libgtest.a ${GMOCK_DIR}/libgmock.a #========================== # OPTIMIZE OPTIONS @@ -29,13 +31,14 @@ INCLUDES = -I. -Icommands -I${FFTW_INCLUDE_DIR} # -pedantic turns off more extensions and generates more warnings # -xHost generates instructions for the highest instruction set available on the compilation host processor -OPTS = ${INCLUDES} -Ofast -std=c++11 -simd -march=native -xHost -m64 -qopenmp -Werror -Wall -pedantic -g +OPTS = ${INCLUDES} -Ofast -std=c++14 -simd -march=native -m64 -qopenmp -Werror -Wall -pedantic -g include Makefile.Objects VPATH=../../../source/module_base\ :../../../source/src_global\ :../../../source/module_orbital\ +:../../../source/src_lcao\ :./\ #========================== @@ -46,6 +49,8 @@ HONG= -DMETIS -DMKL_ILP64 -D__NORMAL -D__ORBITAL ${HONG_FFTW} FP_OBJS_0=main.o\ $(OBJS_BASE)\ $(OBJS_ORBITAL)\ +$(OBJS_CTO)\ +$(OBJS_TEST)\ FP_OBJS=$(patsubst %.o, ${OBJ_DIR}/%.o, ${FP_OBJS_0}) @@ -60,7 +65,7 @@ init : @ if [ ! -d $(OBJ_DIR) ]; then mkdir $(OBJ_DIR); fi @ if [ ! -d $(OBJ_DIR)/README ]; then echo "This directory contains all of the .o files" > $(OBJ_DIR)/README; fi -serial : ${FP_OBJS} +serial : ${FP_OBJS} ${CPLUSPLUS} ${OPTS} $(FP_OBJS) ${LIBS} -o ${VERSION}.x #========================== diff --git a/tests/module_orb/src/Makefile.Objects b/source/module_orbital/test/Makefile.Objects similarity index 81% rename from tests/module_orb/src/Makefile.Objects rename to source/module_orbital/test/Makefile.Objects index 81b3dc995f8..0ac16871004 100644 --- a/tests/module_orb/src/Makefile.Objects +++ b/source/module_orbital/test/Makefile.Objects @@ -39,4 +39,10 @@ ORB_table_beta.o\ ORB_table_phi.o\ ORB_table_alpha.o\ ORB_gen_tables.o\ -ORB_unittest.o\ + +OBJS_CTO=center2_orb-orb11.o\ + +OBJS_TEST=ORB_unittest.o\ +1_snap_equal_test.o\ +#2_snap_finite_diff_test.o\ +#3_center2_finite_diff_test.o\ \ No newline at end of file diff --git a/source/module_orbital/test/ORB_unittest.cpp b/source/module_orbital/test/ORB_unittest.cpp new file mode 100644 index 00000000000..531d921003a --- /dev/null +++ b/source/module_orbital/test/ORB_unittest.cpp @@ -0,0 +1,224 @@ +#include "ORB_unittest.h" + +void test_orb::SetUp() +{ + //test constructor + /*Center2_Orb::Orb11 testcto = Center2_Orb::Orb11( + ORB.Phi[0].PhiLN(0, 0), + ORB.Phi[0].PhiLN(0, 0), + OGT.MOT, Center2_MGT);*/ + // 1. setup orbitals + this->ofs_running.open("log.txt"); + this->count_ntype(); + this->set_files(); + this->set_ekcut(); + + + //2. setup 2-center-integral tables by basic methods + // not including center2orb, it will be set up when needed + // in some test cases. + this->set_orbs(); + //this->set_center2orbs(); + +} + +void test_orb::TearDown() +{ + int* nproj = new int[ORB.get_ntype()]; + for (int i = 0;i < ORB.get_ntype();++i) + nproj[i] = 0; + ooo.clear_after_ions(OGT, ORB, 0, nproj); + delete[] nproj; + return; +} + +void test_orb::set_ekcut() +{ + std::cout << "set lcao_ecut from LCAO files" << std::endl; + //set as max of ekcut from every element + + lcao_ecut=0.0; + std::ifstream in_ao; + + for(int it=0;itcase_dir+ORB.orbital_file[it].c_str())); + if(!in_ao) + { + std::cout << "error : cannot find LCAO file : " << ORB.orbital_file[it] << std::endl; + } + ORB.orbital_file[it] = this->case_dir + ORB.orbital_file[it].c_str(); + string word; + while (in_ao.good()) + { + in_ao >> word; + if(word == "Cutoff(Ry)") break; + } + in_ao >> ek_current; + lcao_ecut = std::max(lcao_ecut,ek_current); + + in_ao.close(); + } + + ORB.ecutwfc=lcao_ecut; + std::cout << "lcao_ecut : " << lcao_ecut << std::endl; + return; +} + +void test_orb::set_orbs() +{ + + ooo.read_orb_first( + ofs_running, + ORB, + ntype_read, + lmax, + lcao_ecut, + lcao_dk, + lcao_dr, + lcao_rmax, + 0, + 0, + 1,//force + 0);//myrank + + int* nproj = new int[ORB.get_ntype()]; + for (int i = 0;i < ORB.get_ntype();++i) + nproj[i] = 0; + const Numerical_Nonlocal beta_[ORB.get_ntype()]; + + ooo.set_orb_tables( + ofs_running, + OGT, + ORB, + lat0, + 0, //no out_descriptor + lmax, + 0, //no nproj + nproj, + beta_); + + delete[] nproj; + return; +} + +void test_orb::set_files() +{ + std::cout << "read names of atomic basis set files" << std::endl; + std::ifstream ifs((this->case_dir + "STRU"),std::ios::in); + + ModuleBase::GlobalFunc::SCAN_BEGIN(ifs,"NUMERICAL_ORBITAL"); + ORB.read_in_flag = true; + + for(int it=0;it> ofile; + ORB.orbital_file.push_back(ofile); + + std::cout << "Numerical orbital file : " << ofile << std::endl; + } + + return; +} + +void test_orb::count_ntype() +{ + std::cout << "count number of atom types" << std::endl; + std::cout << this->case_dir +"STRU" << std::endl; + std::ifstream ifs( (this->case_dir+ "STRU"), std::ios::in); + + if (!ifs) + { + std::cout << "ERROR : file STRU does not exist" < dR = ModuleBase::Vector3(0.001, 0.001, 0.001); + std::cout << this->test_center2_orb11[0][0][0][0][0][0].cal_overlap(R1, R2, 0, 0); +}*/ \ No newline at end of file diff --git a/source/module_orbital/test/ORB_unittest.h b/source/module_orbital/test/ORB_unittest.h new file mode 100644 index 00000000000..828b4890b6f --- /dev/null +++ b/source/module_orbital/test/ORB_unittest.h @@ -0,0 +1,75 @@ +#ifndef _ORBUNITTEST_ +#define _ORBUNITTEST_ + +#include "gtest/gtest.h" +#include "module_orbital/ORB_control.h" +#include "module_base/global_function.h" +#include "src_lcao/center2_orb-orb11.h" +//#include "mock_center2.h" +#include +#include +#include + +#include +#include +#include + +using namespace std; + +class test_orb : public testing::Test +{ +protected: + void SetUp() override; + void TearDown() override; +public: + + LCAO_Orbitals ORB; + ORB_gen_tables OGT; + ORB_gaunt_table Center2_MGT; //gaunt table used in center2orb + ORB_control ooo; + std::ofstream ofs_running; + + std::map < size_t, + std::map>>>>>> test_center2_orb11; +/* + std::map < size_t, + std::map>>>>>> mock_center2_orb11; +*/ + void count_ntype(); //from STRU, count types of elements + void set_files(); //from STRU, read names of LCAO files + void set_ekcut(); //from LCAO files, read and set ekcut + void set_orbs(); //interface to Read_PAO + void set_center2orbs(); //interface to Center2orb + template + void set_single_c2o(int TA, int TB, int LA, int NA, int LB, int NB); + double randr(double Rmax); + void gen_table_center2(); + + + bool force_flag = 0; + int my_rank = 0; + int ntype_read; + + double lcao_ecut = 0; // (Ry) + double lcao_dk = 0.01; + double lcao_dr = 0.01; + double lcao_rmax = 30; // (a.u.) + + int out_descriptor = 0; + int out_r_matrix = 0; + + int lmax=1; + double lat0 = 1.0; + string case_dir = "./GaAs/"; +}; +#endif diff --git a/source/module_orbital/2_UnitTests/README b/source/module_orbital/test/README similarity index 100% rename from source/module_orbital/2_UnitTests/README rename to source/module_orbital/test/README diff --git a/source/module_orbital/test/orb_obj/README b/source/module_orbital/test/orb_obj/README new file mode 100644 index 00000000000..3f49259cf92 --- /dev/null +++ b/source/module_orbital/test/orb_obj/README @@ -0,0 +1 @@ +This directory contains all of the .o files diff --git a/source/run_lcao.cpp b/source/run_lcao.cpp index 04ba3703149..b2194c970be 100644 --- a/source/run_lcao.cpp +++ b/source/run_lcao.cpp @@ -7,21 +7,18 @@ #include "module_neighbor/sltk_atom_arrange.h" #include "src_lcao/LOOP_cell.h" #include "src_io/print_info.h" -#include "module_symmetry/symmetry.h" #include "src_lcao/run_md_lcao.h" -#ifdef __DEEPKS -#include "module_deepks/LCAO_deepks.h" -#endif Run_lcao::Run_lcao(){} Run_lcao::~Run_lcao(){} -void Run_lcao::lcao_line(void) +void Run_lcao::lcao_line(ModuleEnSover::En_Solver *p_ensolver) { ModuleBase::TITLE("Run_lcao","lcao_line"); ModuleBase::timer::tick("Run_lcao", "lcao_line"); - + + //-----------------------init Cell-------------------------- // Setup the unitcell. // improvement: a) separating the first reading of the atom_card and subsequent // cell relaxation. b) put GlobalV::NLOCAL and GlobalV::NBANDS as input parameters @@ -48,163 +45,106 @@ void Run_lcao::lcao_line(void) GlobalV::SEARCH_RADIUS, GlobalV::test_atom_input, INPUT.test_just_neighbor); - } - // setup GlobalV::NBANDS - // Yu Liu add 2021-07-03 - GlobalC::CHR.cal_nelec(); - - // mohan add 2010-09-06 - // Yu Liu move here 2021-06-27 - // because the number of element type - // will easily be ignored, so here - // I warn the user again for each type. - for(int it=0; itInit(INPUT, GlobalC::ucell); + //------------------------------------------------------------ + //------------------init Basis_lcao---------------------- + // Init Basis should be put outside of Ensolver. // * reading the localized orbitals/projectors // * construct the interpolation tables. - ORB_control orb_con; + ORB_control orb_con( + GlobalV::GAMMA_ONLY_LOCAL, + GlobalV::NLOCAL, GlobalV::NBANDS, + GlobalV::NSPIN, GlobalV::DSIZE, + GlobalV::NB2D, GlobalV::DCOLOR, + GlobalV::DRANK, GlobalV::MY_RANK, + GlobalV::CALCULATION, GlobalV::KS_SOLVER); + Init_Basis_lcao(orb_con, INPUT, GlobalC::ucell); + //------------------init Basis_lcao---------------------- + + + //---------------------------MD/Relax------------------ + if (GlobalV::CALCULATION == "md") + { + Run_MD_LCAO run_md_lcao(orb_con.ParaV); + run_md_lcao.opt_cell(orb_con, p_ensolver); + } + else // cell relaxations + { + LOOP_cell lc(orb_con.ParaV); + //keep wfc_gamma or wfc_k remaining + lc.opt_cell(orb_con, p_ensolver); + } + //---------------------------MD/Relax------------------ + + ModuleBase::timer::tick("Run_lcao","lcao_line"); + return; +} + +void Run_lcao::Init_Basis_lcao(ORB_control& orb_con, Input& inp, UnitCell_pseudo& ucell) +{ + // * reading the localized orbitals/projectors + // * construct the interpolation tables. orb_con.read_orb_first( + GlobalV::ofs_running, + GlobalC::ORB, + ucell.ntype, + ucell.lmax, + inp.lcao_ecut, + inp.lcao_dk, + inp.lcao_dr, + inp.lcao_rmax, + GlobalV::out_descriptor, + inp.out_r_matrix, + GlobalV::FORCE, + GlobalV::MY_RANK); + + ucell.infoNL.setupNonlocal( + ucell.ntype, + ucell.atoms, GlobalV::ofs_running, - GlobalC::ORB, - GlobalC::ucell.ntype, - GlobalC::ucell.lmax, - INPUT.lcao_ecut, - INPUT.lcao_dk, - INPUT.lcao_dr, - INPUT.lcao_rmax, - GlobalV::out_descriptor, - INPUT.out_r_matrix, - GlobalV::FORCE, - GlobalV::MY_RANK); - - GlobalC::ucell.infoNL.setupNonlocal( - GlobalC::ucell.ntype, - GlobalC::ucell.atoms, - GlobalV::ofs_running, - GlobalC::ORB - ); + GlobalC::ORB); #ifdef __MPI orb_con.set_orb_tables( GlobalV::ofs_running, GlobalC::UOT, GlobalC::ORB, - GlobalC::ucell.lat0, + ucell.lat0, GlobalV::out_descriptor, Exx_Abfs::Lmax, - GlobalC::ucell.infoNL.nprojmax, - GlobalC::ucell.infoNL.nproj, - GlobalC::ucell.infoNL.Beta); + ucell.infoNL.nprojmax, + ucell.infoNL.nproj, + ucell.infoNL.Beta); #else int Lmax=0; orb_con.set_orb_tables( GlobalV::ofs_running, GlobalC::UOT, GlobalC::ORB, - GlobalC::ucell.lat0, + ucell.lat0, GlobalV::out_descriptor, Lmax, - GlobalC::ucell.infoNL.nprojmax, - GlobalC::ucell.infoNL.nproj, - GlobalC::ucell.infoNL.Beta); + ucell.infoNL.nprojmax, + ucell.infoNL.nproj, + ucell.infoNL.Beta); #endif - orb_con.setup_2d_division(); -//-------------------------------------- -// cell relaxation should begin here -//-------------------------------------- - - // Initalize the plane wave basis set - GlobalC::pw.gen_pw(GlobalV::ofs_running, GlobalC::ucell, GlobalC::kv); - ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running,"INIT PLANEWAVE"); - std::cout << " UNIFORM GRID DIM : " << GlobalC::pw.nx <<" * " << GlobalC::pw.ny <<" * "<< GlobalC::pw.nz << std::endl; - std::cout << " UNIFORM GRID DIM(BIG): " << GlobalC::pw.nbx <<" * " << GlobalC::pw.nby <<" * "<< GlobalC::pw.nbz << std::endl; - - // the symmetry of a variety of systems. - if(GlobalV::CALCULATION == "test") - { - Cal_Test::test_memory(); - ModuleBase::QUIT(); - } - - // initialize the real-space uniform grid for FFT and parallel - // distribution of plane waves - GlobalC::Pgrid.init(GlobalC::pw.ncx, GlobalC::pw.ncy, GlobalC::pw.ncz, GlobalC::pw.nczp, - GlobalC::pw.nrxx, GlobalC::pw.nbz, GlobalC::pw.bz); // mohan add 2010-07-22, update 2011-05-04 - // Calculate Structure factor - GlobalC::pw.setup_structure_factor(); - - // Inititlize the charge density. - GlobalC::CHR.allocate(GlobalV::NSPIN, GlobalC::pw.nrxx, GlobalC::pw.ngmc); - ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running,"INIT CHARGE"); - - // Initializee the potential. - GlobalC::pot.allocate(GlobalC::pw.nrxx); - ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running,"INIT POTENTIAL"); - - - // Peize Lin add 2018-11-30 -#ifdef __MPI - if(GlobalV::CALCULATION=="nscf") - { - switch(GlobalC::exx_global.info.hybrid_type) - { - case Exx_Global::Hybrid_Type::HF: - case Exx_Global::Hybrid_Type::PBE0: - case Exx_Global::Hybrid_Type::HSE: - GlobalC::exx_global.info.set_xcfunc(GlobalC::xcf); - break; - } - } -#endif - -#ifdef __DEEPKS - //wenfei 2021-12-19 - //if we are performing DeePKS calculations, we need to load a model - if (GlobalV::out_descriptor) - { - if (GlobalV::deepks_scf) - { - // load the DeePKS model from deep neural network - GlobalC::ld.load_model(INPUT.model_file); - } - } -#endif - - if(GlobalV::CALCULATION=="md") - { - Run_MD_LCAO run_md_lcao(orb_con.ParaV); - run_md_lcao.opt_cell(orb_con); - } - else // cell relaxations - { - LOOP_cell lc(orb_con.ParaV); - //keep wfc_gamma or wfc_k remaining - lc.opt_cell(orb_con); - } - - ModuleBase::timer::tick("Run_lcao","lcao_line"); - return; -} + if (orb_con.setup_2d) + orb_con.setup_2d_division(GlobalV::ofs_running, GlobalV::ofs_warning); +} \ No newline at end of file diff --git a/source/run_lcao.h b/source/run_lcao.h index f391b996882..bcc4c0d91ce 100644 --- a/source/run_lcao.h +++ b/source/run_lcao.h @@ -9,6 +9,7 @@ #include "module_base/global_variable.h" #include "module_orbital/ORB_control.h" #include "input.h" +#include "module_ensolver/en_solver.h" class Run_lcao { @@ -19,8 +20,10 @@ class Run_lcao ~Run_lcao(); // perform Linear Combination of Atomic Orbitals (LCAO) calculations - static void lcao_line(void); + static void lcao_line(ModuleEnSover::En_Solver *p_ensolver); +private: + static void Init_Basis_lcao(ORB_control& orb_con, Input& inp, UnitCell_pseudo& ucell); }; #endif diff --git a/source/run_pw.cpp b/source/run_pw.cpp index 6b88884f384..cd526afb24b 100644 --- a/source/run_pw.cpp +++ b/source/run_pw.cpp @@ -1,21 +1,18 @@ #include "run_pw.h" #include "src_pw/global.h" -#include "src_pw/energy.h" -#include "input.h" -#include "src_io/optical.h" #include "src_io/cal_test.h" #include "src_io/winput.h" +#include "src_io/optical.h" #include "src_io/numerical_basis.h" #include "src_io/numerical_descriptor.h" #include "src_io/print_info.h" -#include "module_symmetry/symmetry.h" #include "src_ions/Cell_PW.h" #include "src_pw/run_md_pw.h" Run_pw::Run_pw(){} Run_pw::~Run_pw(){} -void Run_pw::plane_wave_line(void) +void Run_pw::plane_wave_line(ModuleEnSover::En_Solver *p_ensolver) { ModuleBase::TITLE("Run_pw","plane_wave_line"); ModuleBase::timer::tick("Run_pw","plane_wave_line"); @@ -28,44 +25,6 @@ void Run_pw::plane_wave_line(void) #else GlobalC::ucell.setup_cell( GlobalV::global_pseudo_dir, GlobalV::global_atom_card, GlobalV::ofs_running); #endif - //GlobalC::ucell.setup_cell( GlobalV::global_pseudo_dir , GlobalV::global_atom_card , GlobalV::ofs_running, GlobalV::NLOCAL, GlobalV::NBANDS); - - // setup GlobalV::NBANDS - // Yu Liu add 2021-07-03 - GlobalC::CHR.cal_nelec(); - - // mohan add 2010-09-06 - // Yu Liu move here 2021-06-27 - // because the number of element type - // will easily be ignored, so here - // I warn the user again for each type. - for(int it=0; itInit(INPUT, GlobalC::ucell); - // Calculate Structure factor - GlobalC::pw.setup_structure_factor(); - // cout<<"after pgrid init nrxx = "<Run(istep,GlobalC::ucell); + p_ensolver->cal_Energy(GlobalC::en); eiter = elec.iter; #ifdef __LCAO #ifdef __MPI @@ -182,12 +175,6 @@ void Ions::opt_ions_pw(void) elec.non_self_consistent(istep-1); eiter = elec.iter; } - // mohan added 2021-01-28, perform stochastic calculations - else if(GlobalV::CALCULATION=="scf-sto" || GlobalV::CALCULATION=="relax-sto" || GlobalV::CALCULATION=="md-sto") - { - elec_sto.scf_stochastic(istep-1); - eiter = elec_sto.iter; - } if(GlobalC::pot.out_potential == 2) { @@ -204,7 +191,7 @@ void Ions::opt_ions_pw(void) if (GlobalV::CALCULATION=="scf" || GlobalV::CALCULATION=="relax" || GlobalV::CALCULATION=="cell-relax") { - stop = this->after_scf(istep, force_step, stress_step); // pengfei Li 2018-05-14 + stop = this->after_scf(p_ensolver, istep, force_step, stress_step); // pengfei Li 2018-05-14 } time_t fend = time(NULL); @@ -255,20 +242,20 @@ void Ions::opt_ions_pw(void) return; } -bool Ions::after_scf(const int &istep, int &force_step, int &stress_step) +bool Ions::after_scf(ModuleEnSover::En_Solver *p_ensolver, const int &istep, int &force_step, int &stress_step) { ModuleBase::TITLE("Ions","after_scf"); //calculate and gather all parts of total ionic forces ModuleBase::matrix force; if(GlobalV::FORCE) { - this->gather_force_pw(force); + this->gather_force_pw(p_ensolver, force); } //calculate and gather all parts of stress ModuleBase::matrix stress; if(GlobalV::STRESS) { - this->gather_stress_pw(stress); + this->gather_stress_pw(p_ensolver, stress); } //stop in last step if(istep==GlobalV::NSTEP) @@ -303,18 +290,21 @@ bool Ions::after_scf(const int &istep, int &force_step, int &stress_step) return 1; } -void Ions::gather_force_pw(ModuleBase::matrix &force) +void Ions::gather_force_pw(ModuleEnSover::En_Solver *p_ensolver, ModuleBase::matrix &force) { ModuleBase::TITLE("Ions","gather_force_pw"); - Forces fcs; - fcs.init(force); + // Forces fcs; + // fcs.init(force); + p_ensolver->cal_Force(force); } -void Ions::gather_stress_pw(ModuleBase::matrix& stress) + +void Ions::gather_stress_pw(ModuleEnSover::En_Solver *p_ensolver, ModuleBase::matrix& stress) { ModuleBase::TITLE("Ions","gather_stress_pw"); //basic stress - Stress_PW ss; - ss.cal_stress(stress); + // Stress_PW ss; + // ss.cal_stress(stress); + p_ensolver->cal_Stress(stress); //external stress double unit_transform = 0.0; unit_transform = ModuleBase::RYDBERG_SI / pow(ModuleBase::BOHR_RADIUS_SI,3) * 1.0e-8; diff --git a/source/src_ions/ions.h b/source/src_ions/ions.h index f2b0aac6483..3baba13652c 100644 --- a/source/src_ions/ions.h +++ b/source/src_ions/ions.h @@ -9,6 +9,7 @@ #include "../src_pw/sto_elec.h" //mohan added 2021-01-28 #include "ions_move_methods.h" #include "lattice_change_methods.h" +#include "module_ensolver/en_solver.h" class Ions { @@ -18,7 +19,7 @@ class Ions Ions(){}; ~Ions(){}; - void opt_ions_pw(void); + void opt_ions_pw(ModuleEnSover::En_Solver *p_ensolver); private: @@ -41,9 +42,9 @@ class Ions Lattice_Change_Methods LCM; //seperate force_stress function first - bool after_scf(const int &istep, int &force_step, int &stress_step); - void gather_force_pw(ModuleBase::matrix &force); - void gather_stress_pw(ModuleBase::matrix& stress); + bool after_scf(ModuleEnSover::En_Solver *p_ensolver,const int &istep, int &force_step, int &stress_step); + void gather_force_pw(ModuleEnSover::En_Solver *p_ensolver, ModuleBase::matrix &force); + void gather_stress_pw(ModuleEnSover::En_Solver *p_ensolver, ModuleBase::matrix& stress); bool if_do_relax(); bool if_do_cellrelax(); bool do_relax(const int& istep, int& jstep, const ModuleBase::matrix& ionic_force, const double& total_energy); diff --git a/source/src_lcao/FORCE_STRESS.cpp b/source/src_lcao/FORCE_STRESS.cpp index 6b432d481bd..e7be8ea6dd1 100644 --- a/source/src_lcao/FORCE_STRESS.cpp +++ b/source/src_lcao/FORCE_STRESS.cpp @@ -14,8 +14,9 @@ double Force_Stress_LCAO::force_invalid_threshold_ev = 0.00; double Force_Stress_LCAO::output_acc = 1.0e-8; -Force_Stress_LCAO::Force_Stress_LCAO (){} -Force_Stress_LCAO::~Force_Stress_LCAO (){} +Force_Stress_LCAO::Force_Stress_LCAO(Record_adj& ra) : + RA(&ra){} +Force_Stress_LCAO::~Force_Stress_LCAO() {} #include "../src_pw/efield.h" void Force_Stress_LCAO::getForceStress( @@ -749,6 +750,7 @@ void Force_Stress_LCAO::calForceStressIntegralPart( flk.ftable_k( isforce, isstress, + *this->RA, lowf.wfc_k, loc, foverlap, diff --git a/source/src_lcao/FORCE_STRESS.h b/source/src_lcao/FORCE_STRESS.h index 1c6190d60ad..d4bc36f4218 100644 --- a/source/src_lcao/FORCE_STRESS.h +++ b/source/src_lcao/FORCE_STRESS.h @@ -22,11 +22,12 @@ class Force_Stress_LCAO public : - Force_Stress_LCAO (); + Force_Stress_LCAO (Record_adj &ra); ~Force_Stress_LCAO (); - private: - +private: + + Record_adj* RA; Force_LCAO_k flk; // Force_LCAO_gamma flg; Stress_Func sc_pw; diff --git a/source/src_lcao/FORCE_k.cpp b/source/src_lcao/FORCE_k.cpp index a62ce97802b..3c41072b157 100644 --- a/source/src_lcao/FORCE_k.cpp +++ b/source/src_lcao/FORCE_k.cpp @@ -22,6 +22,7 @@ Force_LCAO_k::~Force_LCAO_k () void Force_LCAO_k::ftable_k ( const bool isforce, const bool isstress, + Record_adj &ra, std::vector& wfc_k, Local_Orbital_Charge &loc, ModuleBase::matrix& foverlap, @@ -49,7 +50,7 @@ void Force_LCAO_k::ftable_k ( // calculate the energy density matrix // and the force related to overlap matrix and energy density matrix. - this->cal_foverlap_k(isforce, isstress, wfc_k, loc, foverlap, soverlap); + this->cal_foverlap_k(isforce, isstress, ra, wfc_k, loc, foverlap, soverlap); // calculate the density matrix double** dm2d = new double*[GlobalV::NSPIN]; @@ -60,11 +61,9 @@ void Force_LCAO_k::ftable_k ( } ModuleBase::Memory::record ("Force_LCAO_k", "dm2d", GlobalV::NSPIN*pv->nnr, "double"); - Record_adj RA; - RA.for_2d(*pv); - loc.cal_dm_R(loc.dm_k, RA, dm2d); + loc.cal_dm_R(loc.dm_k, ra, dm2d); - this->cal_ftvnl_dphi_k(dm2d, isforce, isstress, ftvnl_dphi, stvnl_dphi); + this->cal_ftvnl_dphi_k(dm2d, isforce, isstress, ra, ftvnl_dphi, stvnl_dphi); // --------------------------------------- @@ -237,6 +236,7 @@ void Force_LCAO_k::finish_k(void) void Force_LCAO_k::cal_foverlap_k( const bool isforce, const bool isstress, + Record_adj &ra, std::vector& wfc_k, Local_Orbital_Charge &loc, ModuleBase::matrix& foverlap, @@ -255,9 +255,6 @@ void Force_LCAO_k::cal_foverlap_k( edm2d[is] = new double[pv->nnr]; ModuleBase::GlobalFunc::ZEROS(edm2d[is], pv->nnr); } - - Record_adj RA; - RA.for_2d(*pv); //-------------------------------------------- // calculate the energy density matrix here. @@ -279,7 +276,7 @@ void Force_LCAO_k::cal_foverlap_k( wfc_k, edm_k); loc.cal_dm_R(edm_k, - RA, edm2d); + ra, edm2d); ModuleBase::timer::tick("Force_LCAO_k", "cal_edm_2d"); //-------------------------------------------- @@ -298,10 +295,10 @@ void Force_LCAO_k::cal_foverlap_k( for(int I1=0; I1na; ++I1) { const int start1 = GlobalC::ucell.itiaiw2iwt(T1,I1,0); - for (int cb = 0; cb < RA.na_each[iat]; ++cb) + for (int cb = 0; cb < ra.na_each[iat]; ++cb) { - const int T2 = RA.info[iat][cb][3]; - const int I2 = RA.info[iat][cb][4]; + const int T2 = ra.info[iat][cb][3]; + const int I2 = ra.info[iat][cb][4]; const int start2 = GlobalC::ucell.itiaiw2iwt(T2, I2, 0); Atom* atom2 = &GlobalC::ucell.atoms[T2]; @@ -374,7 +371,6 @@ void Force_LCAO_k::cal_foverlap_k( } delete[] edm2d; - RA.delete_grid();//xiaohui add 2015-02-04 ModuleBase::timer::tick("Force_LCAO_k","cal_foverlap_k"); return; } @@ -382,8 +378,9 @@ void Force_LCAO_k::cal_foverlap_k( void Force_LCAO_k::cal_ftvnl_dphi_k( double** dm2d, const bool isforce, - const bool isstress, - ModuleBase::matrix& ftvnl_dphi, + const bool isstress, + Record_adj &ra, + ModuleBase::matrix& ftvnl_dphi, ModuleBase::matrix& stvnl_dphi) { ModuleBase::TITLE("Force_LCAO_k","cal_ftvnl_dphi"); @@ -393,8 +390,6 @@ void Force_LCAO_k::cal_ftvnl_dphi_k( // get the adjacent atom's information. // GlobalV::ofs_running << " calculate the ftvnl_dphi_k force" << std::endl; - Record_adj RA; - RA.for_2d(*this->UHM->LM->ParaV); int irr = 0; for(int T1=0; T1UHM->LM->ParaV); + RA.for_2d(*this->UHM->LM->ParaV, GlobalV::GAMMA_ONLY_LOCAL); double *test; test = new double[GlobalV::NLOCAL * GlobalV::NLOCAL]; diff --git a/source/src_lcao/FORCE_k.h b/source/src_lcao/FORCE_k.h index 04297f8b46b..57a902d6881 100644 --- a/source/src_lcao/FORCE_k.h +++ b/source/src_lcao/FORCE_k.h @@ -27,6 +27,7 @@ class Force_LCAO_k : public Force_LCAO_gamma void ftable_k ( const bool isforce, const bool isstress, + Record_adj &ra, std::vector& wfc_k, Local_Orbital_Charge &loc, ModuleBase::matrix& foverlap, @@ -50,10 +51,11 @@ class Force_LCAO_k : public Force_LCAO_gamma void finish_k(void); // calculate the force due to < dphi | beta > < beta | phi > - void cal_ftvnl_dphi_k(double** dm2d, const bool isforce, const bool isstress, ModuleBase::matrix& ftvnl_dphi, ModuleBase::matrix& stvnl_dphi); + void cal_ftvnl_dphi_k(double** dm2d, const bool isforce, const bool isstress, Record_adj& ra, + ModuleBase::matrix& ftvnl_dphi, ModuleBase::matrix& stvnl_dphi); // calculate the overlap force - void cal_foverlap_k(const bool isforce, const bool isstress, std::vector& wfc_k, + void cal_foverlap_k(const bool isforce, const bool isstress, Record_adj &ra, std::vector& wfc_k, Local_Orbital_Charge& loc, ModuleBase::matrix& foverlap, ModuleBase::matrix& soverlap); // calculate the force due to < phi | Vlocal | dphi > diff --git a/source/src_lcao/LCAO_matrix.h b/source/src_lcao/LCAO_matrix.h index 30e619005e3..53fce2a94fa 100644 --- a/source/src_lcao/LCAO_matrix.h +++ b/source/src_lcao/LCAO_matrix.h @@ -5,7 +5,7 @@ #include "../module_base/global_variable.h" #include "../module_base/vector3.h" #include "../module_base/complexmatrix.h" -#include "../src_parallel/parallel_orbitals.h" +#include "../module_orbital/parallel_orbitals.h" // add by jingan for map<> in 2021-12-2, will be deleted in the future #include "../src_ri/abfs-vector3_order.h" diff --git a/source/src_lcao/LCAO_nnr.cpp b/source/src_lcao/LCAO_nnr.cpp index 222bcba181c..6561cb7d4e7 100644 --- a/source/src_lcao/LCAO_nnr.cpp +++ b/source/src_lcao/LCAO_nnr.cpp @@ -5,154 +5,6 @@ #ifdef __DEEPKS #include "../module_deepks/LCAO_deepks.h" #endif -//---------------------------- -// define a global class obj. -//---------------------------- - -// be called in LOOP_ions.cpp -void Parallel_Orbitals::cal_nnr() -{ - ModuleBase::TITLE("LCAO_nnr","cal_nnr"); - - delete[] nlocdim; - delete[] nlocstart; - nlocdim = new int[GlobalC::ucell.nat]; - nlocstart = new int[GlobalC::ucell.nat]; - ModuleBase::GlobalFunc::ZEROS(nlocdim, GlobalC::ucell.nat); - ModuleBase::GlobalFunc::ZEROS(nlocstart, GlobalC::ucell.nat); - - this->nnr = 0; - int start = 0; - //int ind1 = 0; - int iat = 0; - - // (1) find the adjacent atoms of atom[T1,I1]; - ModuleBase::Vector3 tau1; - ModuleBase::Vector3 tau2; - ModuleBase::Vector3 dtau; - ModuleBase::Vector3 tau0; - ModuleBase::Vector3 dtau1; - ModuleBase::Vector3 dtau2; - - for (int T1 = 0; T1 < GlobalC::ucell.ntype; T1++) - { - for (int I1 = 0; I1 < GlobalC::ucell.atoms[T1].na; I1++) - { - tau1 = GlobalC::ucell.atoms[T1].tau[I1]; - //GlobalC::GridD.Find_atom( tau1 ); - GlobalC::GridD.Find_atom(GlobalC::ucell, tau1 ,T1, I1); - const int start1 = GlobalC::ucell.itiaiw2iwt(T1, I1, 0); - this->nlocstart[iat] = nnr; - int nw1 = GlobalC::ucell.atoms[T1].nw * GlobalV::NPOL; - - // (2) search among all adjacent atoms. - for (int ad = 0; ad < GlobalC::GridD.getAdjacentNum()+1; ad++) - { - const int T2 = GlobalC::GridD.getType(ad); - const int I2 = GlobalC::GridD.getNatom(ad); - //const int iat2 = GlobalC::ucell.itia2iat(T2, I2); - const int start2 = GlobalC::ucell.itiaiw2iwt(T2, I2, 0); - int nw2 = GlobalC::ucell.atoms[T2].nw * GlobalV::NPOL; - - tau2 = GlobalC::GridD.getAdjacentTau(ad); - - dtau = tau2 - tau1; - double distance = dtau.norm() * GlobalC::ucell.lat0; - double rcut = GlobalC::ORB.Phi[T1].getRcut() + GlobalC::ORB.Phi[T2].getRcut(); - - if(distance < rcut) - { - //-------------------------------------------------- - // calculate how many matrix elements are in - // this processor. - for(int ii=0; iitrace_loc_row[iw1_all]; - if(mu<0)continue; - - for(int jj=0; jjtrace_loc_col[iw2_all]; - if(nu<0)continue; - - // orbital numbers for this atom (iat), - // seperated by atoms in different cells. - this->nlocdim[iat]++; - - ++nnr; - }// end jj - }// end ii - }//end distance - // there is another possibility that i and j are adjacent atoms. - // which is that are adjacents while are also - // adjacents, these considerations are only considered in k-point - // algorithm, - // mohan fix bug 2012-07-03 - else if(distance >= rcut) - { - for (int ad0 = 0; ad0 < GlobalC::GridD.getAdjacentNum()+1; ++ad0) - { - const int T0 = GlobalC::GridD.getType(ad0); - const int I0 = GlobalC::GridD.getNatom(ad0); - //const int iat0 = GlobalC::ucell.itia2iat(T0, I0); - //const int start0 = GlobalC::ucell.itiaiw2iwt(T0, I0, 0); - - tau0 = GlobalC::GridD.getAdjacentTau(ad0); - dtau1 = tau0 - tau1; - double distance1 = dtau1.norm() * GlobalC::ucell.lat0; - double rcut1 = GlobalC::ORB.Phi[T1].getRcut() + GlobalC::ucell.infoNL.Beta[T0].get_rcut_max(); - - dtau2 = tau0 - tau2; - double distance2 = dtau2.norm() * GlobalC::ucell.lat0; - double rcut2 = GlobalC::ORB.Phi[T2].getRcut() + GlobalC::ucell.infoNL.Beta[T0].get_rcut_max(); - - if( distance1 < rcut1 && distance2 < rcut2 ) - { - for(int ii=0; iitrace_loc_row[iw1_all]; - if(mu<0)continue; - - for(int jj=0; jjtrace_loc_col[iw2_all]; - if(nu<0)continue; - - // orbital numbers for this atom (iat), - // seperated by atoms in different cells. - this->nlocdim[iat]++; - - ++nnr; - } - } - break; - } // dis1, dis2 - }//ad0 - } - }// end ad - - //start position of atom[T1,I1] - start += nw1; - ++iat; - }// end I1 - } // end T1 - - //xiaohui add 'GlobalV::OUT_LEVEL' line, 2015-09-16 - if(GlobalV::OUT_LEVEL != "m") ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"nnr",nnr); -// for(int iat=0; iatLM); - ions.opt_ions(); + ions.opt_ions(p_ensolver); // mohan update 2021-02-10 orb_con.clear_after_ions(GlobalC::UOT, GlobalC::ORB, GlobalV::out_descriptor, GlobalC::ucell.infoNL.nproj); diff --git a/source/src_lcao/LOOP_cell.h b/source/src_lcao/LOOP_cell.h index a4f8dedd057..b5bc6ba153a 100644 --- a/source/src_lcao/LOOP_cell.h +++ b/source/src_lcao/LOOP_cell.h @@ -5,6 +5,7 @@ #include "module_base/complexmatrix.h" #include "module_orbital/ORB_control.h" #include "src_lcao/LCAO_matrix.h" +#include "module_ensolver/en_solver.h" class LOOP_cell { @@ -13,7 +14,7 @@ class LOOP_cell LOOP_cell(Parallel_Orbitals &pv); ~LOOP_cell(); - void opt_cell(ORB_control &orb_con); + void opt_cell(ORB_control &orb_con, ModuleEnSover::En_Solver *p_ensolver); private: LCAO_Matrix LM; diff --git a/source/src_lcao/LOOP_elec.cpp b/source/src_lcao/LOOP_elec.cpp index 010a523a546..48f00c39416 100644 --- a/source/src_lcao/LOOP_elec.cpp +++ b/source/src_lcao/LOOP_elec.cpp @@ -26,6 +26,7 @@ #endif void LOOP_elec::solve_elec_stru(const int& istep, + Record_adj &ra, Local_Orbital_Charge& loc, Local_Orbital_wfc& lowf, LCAO_Hamilt& uhm_in) @@ -36,7 +37,7 @@ void LOOP_elec::solve_elec_stru(const int& istep, this->UHM = &uhm_in; // prepare HS matrices, prepare grid integral - this->set_matrix_grid(); + this->set_matrix_grid(ra); // density matrix extrapolation and prepare S,T,VNL matrices this->before_solver(istep, loc, lowf); // do self-interaction calculations / nscf/ tddft, etc. @@ -47,7 +48,7 @@ void LOOP_elec::solve_elec_stru(const int& istep, } -void LOOP_elec::set_matrix_grid(void) +void LOOP_elec::set_matrix_grid(Record_adj &ra) { ModuleBase::TITLE("LOOP_elec","set_matrix_grid"); ModuleBase::timer::tick("LOOP_elec","set_matrix_grid"); @@ -77,13 +78,13 @@ void LOOP_elec::set_matrix_grid(void) GlobalC::pw.nbx, GlobalC::pw.nby, GlobalC::pw.nbz, GlobalC::pw.nbxx, GlobalC::pw.nbzp_start, GlobalC::pw.nbzp); - // (2) If k point is used here, allocate HlocR after atom_arrange. + // (2)For each atom, calculate the adjacent atoms in different cells + // and allocate the space for H(R) and S(R). + // If k point is used here, allocate HlocR after atom_arrange. + Parallel_Orbitals* pv = this->UHM->LM->ParaV; + ra.for_2d(*pv, GlobalV::GAMMA_ONLY_LOCAL); if(!GlobalV::GAMMA_ONLY_LOCAL) { - // For each atom, calculate the adjacent atoms in different cells - // and allocate the space for H(R) and S(R). - Parallel_Orbitals* pv = this->UHM->LM->ParaV; - pv->cal_nnr(); this->UHM->LM->allocate_HS_R(pv->nnr); #ifdef __DEEPKS GlobalC::ld.allocate_V_deltaR(pv->nnr); diff --git a/source/src_lcao/LOOP_elec.h b/source/src_lcao/LOOP_elec.h index 2c4ef229808..dbf1538f3bc 100644 --- a/source/src_lcao/LOOP_elec.h +++ b/source/src_lcao/LOOP_elec.h @@ -18,6 +18,7 @@ class LOOP_elec // mohan add 2021-02-09 void solve_elec_stru(const int& istep, + Record_adj &ra, Local_Orbital_Charge& loc, Local_Orbital_wfc& low, LCAO_Hamilt& uhm_in); @@ -25,7 +26,7 @@ class LOOP_elec private: // set matrix and grid integral - void set_matrix_grid(void); + void set_matrix_grid(Record_adj &ra); void before_solver(const int& istep, Local_Orbital_Charge& loc, diff --git a/source/src_lcao/LOOP_ions.cpp b/source/src_lcao/LOOP_ions.cpp index 86d00928737..02efbcce737 100644 --- a/source/src_lcao/LOOP_ions.cpp +++ b/source/src_lcao/LOOP_ions.cpp @@ -1,6 +1,6 @@ #include "LOOP_ions.h" #include "../src_pw/global.h" -#include "../src_parallel/parallel_orbitals.h" +#include "../module_orbital/parallel_orbitals.h" #include "../src_pdiag/pdiag_double.h" #include "FORCE_STRESS.h" #include "../module_base/global_function.h" @@ -35,7 +35,7 @@ LOOP_ions::LOOP_ions(LCAO_Matrix &lm) LOOP_ions::~LOOP_ions() {} -void LOOP_ions::opt_ions() +void LOOP_ions::opt_ions(ModuleEnSover::En_Solver *p_ensolver) { ModuleBase::TITLE("LOOP_ions","opt_ions"); ModuleBase::timer::tick("LOOP_ions","opt_ions"); @@ -149,11 +149,10 @@ void LOOP_ions::opt_ions() GlobalC::en.evdw = vdwd3.get_energy(); } - + Record_adj RA; // solve electronic structures in terms of LCAO - // mohan add 2021-02-09 - LOE.solve_elec_stru(this->istep, this->LOC, this->LOWF, this->UHM); - + // mohan add 2021-02-09 + p_ensolver->Run(this->istep, RA, this->LOC, this->LOWF, this->UHM); time_t eend = time(NULL); @@ -234,10 +233,10 @@ void LOOP_ions::opt_ions() time_t fstart = time(NULL); if (GlobalV::CALCULATION=="scf" || GlobalV::CALCULATION=="relax" || GlobalV::CALCULATION=="cell-relax") { - stop = this->force_stress(istep, force_step, stress_step); + stop = this->force_stress(istep, force_step, stress_step, RA); } time_t fend = time(NULL); - + RA.delete_grid(); // PLEASE move the details of CE to other places // mohan add 2021-03-25 //xiaohui add 2014-07-07, for second-order extrapolation @@ -299,7 +298,8 @@ void LOOP_ions::opt_ions() bool LOOP_ions::force_stress( const int &istep, int &force_step, - int &stress_step) + int& stress_step, + Record_adj &ra) { ModuleBase::TITLE("LOOP_ions","force_stress"); @@ -313,7 +313,7 @@ bool LOOP_ions::force_stress( ModuleBase::matrix fcs; // set stress matrix ModuleBase::matrix scs; - Force_Stress_LCAO FSL; + Force_Stress_LCAO FSL(ra); FSL.getForceStress(GlobalV::FORCE, GlobalV::STRESS, GlobalV::TEST_FORCE, GlobalV::TEST_STRESS, this->LOC, this->LOWF, this->UHM, fcs, scs); @@ -535,12 +535,13 @@ void LOOP_ions::final_scf(void) GlobalC::pw.nbxx, GlobalC::pw.nbzp_start, GlobalC::pw.nbzp); // (2) If k point is used here, allocate HlocR after atom_arrange. - if(!GlobalV::GAMMA_ONLY_LOCAL) + Parallel_Orbitals* pv = this->UHM.LM->ParaV; + Record_adj RA; + RA.for_2d(*pv, GlobalV::GAMMA_ONLY_LOCAL); + if (!GlobalV::GAMMA_ONLY_LOCAL) { // For each atom, calculate the adjacent atoms in different cells // and allocate the space for H(R) and S(R). - Parallel_Orbitals* pv = this->UHM.LM->ParaV; - pv->cal_nnr(); this->UHM.LM->allocate_HS_R(pv->nnr); #ifdef __DEEPKS GlobalC::ld.allocate_V_deltaR(pv->nnr); diff --git a/source/src_lcao/LOOP_ions.h b/source/src_lcao/LOOP_ions.h index 3e548b7ee47..a9ec64622e1 100644 --- a/source/src_lcao/LOOP_ions.h +++ b/source/src_lcao/LOOP_ions.h @@ -8,6 +8,7 @@ #include "src_lcao/local_orbital_wfc.h" #include "module_orbital/ORB_control.h" #include "src_lcao/LCAO_hamilt.h" +#include "module_ensolver/en_solver.h" #include @@ -24,7 +25,7 @@ class LOOP_ions Local_Orbital_Charge LOC; LCAO_Hamilt UHM; - void opt_ions(); //output for dos + void opt_ions(ModuleEnSover::En_Solver *p_ensolver); //output for dos void output_HS_R( const std::string &SR_filename="data-SR-sparse_SPIN0.csr", const std::string &HR_filename_up="data-HR-sparse_SPIN0.csr", @@ -46,7 +47,7 @@ class LOOP_ions // the renew of structure factors, etc. should be ran in other places // the 'IMM' and 'LCM' objects should be passed to force_stress() via parameters list // mohan note 2021-03-23 - bool force_stress(const int &istep, int &force_step, int &stress_step); + bool force_stress(const int &istep, int &force_step, int &stress_step, Record_adj &ra); int istep; diff --git a/source/src_lcao/center2_orb-orb11.cpp b/source/src_lcao/center2_orb-orb11.cpp index 5b24f12f132..34311cbdd53 100644 --- a/source/src_lcao/center2_orb-orb11.cpp +++ b/source/src_lcao/center2_orb-orb11.cpp @@ -151,3 +151,93 @@ double Center2_Orb::Orb11::cal_overlap( return overlap; } + +ModuleBase::Vector3 Center2_Orb::Orb11::cal_grad_overlap( //caoyu add 2021-11-19 + const ModuleBase::Vector3 &RA, const ModuleBase::Vector3 &RB, + const int& mA, const int& mB) const +{ + const double tiny1 = 1e-12; // same as `cal_overlap` + const double tiny2 = 1e-10; //same as `cal_overlap` + + const ModuleBase::Vector3 delta_R = RB-RA; + const double distance_true = delta_R.norm(); + const double distance = (distance_true>=tiny1) ? distance_true : distance_true+tiny1; + const double RcutA = nA.getRcut(); + const double RcutB = nB.getRcut(); + if( distance > (RcutA + RcutB) ) + return ModuleBase::Vector3(0.0, 0.0, 0.0); + + const int LA = nA.getL(); + const int LB = nB.getL(); + + std::vector rly; + std::vector> tmp_grly; + std::vector> grly; + ModuleBase::Ylm::grad_rl_sph_harm( + LA + LB, + delta_R.x, delta_R.y, delta_R.z, + rly, tmp_grly); + for (const auto& tmp_ele : tmp_grly) + { + ModuleBase::Vector3 ele(tmp_ele[0], tmp_ele[1], tmp_ele[2]); + grly.push_back(ele); + } + + ModuleBase::Vector3 grad_overlap(0.0, 0.0, 0.0); + + for (const auto& tb_r : Table_r) + { + const int LAB = tb_r.first; + for( int mAB=0; mAB!=2*LAB+1; ++mAB ) + // const int mAB = mA + mB; + { + const double Gaunt_real_A_B_AB = + MGT.Gaunt_Coefficients ( + MGT.get_lm_index(LA,mA), + MGT.get_lm_index(LB,mB), + MGT.get_lm_index(LAB,mAB)); + if( 0==Gaunt_real_A_B_AB ) continue; + + const double ylm_solid = rly[ MGT.get_lm_index(LAB, mAB) ]; + const double ylm_real = + (distance > tiny2) ? + ylm_solid / pow(distance,LAB) : + ylm_solid; + + const ModuleBase::Vector3 gylm_solid = grly[MGT.get_lm_index(LAB, mAB)]; + const ModuleBase::Vector3 gylm_real = + (distance > tiny2) ? + gylm_solid / pow(distance,LAB) : + gylm_solid; + + const double i_exp = std::pow(-1.0, (LA - LB - LAB) / 2); + + const double Interp_Tlm = + (distance > tiny2) ? + ModuleBase::PolyInt::Polynomial_Interpolation( + ModuleBase::GlobalFunc::VECTOR_TO_PTR(tb_r.second), + MOT.get_rmesh(RcutA, RcutB), + MOT.dr, + distance ) : + tb_r.second.at(0); + + const double grad_Interp_Tlm = + (distance > tiny2) ? + ModuleBase::PolyInt::Polynomial_Interpolation( + ModuleBase::GlobalFunc::VECTOR_TO_PTR(this->Table_dr.at(LAB)), + MOT.get_rmesh(RcutA, RcutB), + MOT.dr, + distance) //Interp(Table_dr) + - Interp_Tlm * LAB / distance : + 0.0; + + grad_overlap += + i_exp // pow(2*PI,1.5) + * Gaunt_real_A_B_AB + * (Interp_Tlm * gylm_real + + grad_Interp_Tlm * ylm_real *delta_R / distance); + } + } + + return grad_overlap; +} \ No newline at end of file diff --git a/source/src_lcao/center2_orb-orb11.h b/source/src_lcao/center2_orb-orb11.h index d35b2efd22a..6266ee82772 100644 --- a/source/src_lcao/center2_orb-orb11.h +++ b/source/src_lcao/center2_orb-orb11.h @@ -34,8 +34,12 @@ class Center2_Orb::Orb11 double cal_overlap( const ModuleBase::Vector3 &RA, const ModuleBase::Vector3 &RB, // unit: Bohr const int &mA, const int &mB) const; - - private: + + ModuleBase::Vector3 cal_grad_overlap( //caoyu add 2021-11-19 + const ModuleBase::Vector3 &RA, const ModuleBase::Vector3 &RB, // unit: Bohr + const int& mA, const int& mB) const; + +private: const Numerical_Orbital_Lm &nA; const Numerical_Orbital_Lm &nB; diff --git a/source/src_lcao/dftu.h b/source/src_lcao/dftu.h index 6e51f1cba86..29fc02e06b7 100644 --- a/source/src_lcao/dftu.h +++ b/source/src_lcao/dftu.h @@ -11,7 +11,7 @@ #include "dftu_relax.h" #include "../module_cell/unitcell_pseudo.h" -#include "../src_parallel/parallel_orbitals.h" +#include "../module_orbital/parallel_orbitals.h" using namespace std; @@ -62,10 +62,6 @@ class DFTU : public DFTU_RELAX double EU; int iter_dftu; - -private: - LCAO_Matrix* LM; - }; } namespace GlobalC diff --git a/source/src_lcao/dftu_relax.h b/source/src_lcao/dftu_relax.h index 7d49e4d3b14..d0672304af8 100644 --- a/source/src_lcao/dftu_relax.h +++ b/source/src_lcao/dftu_relax.h @@ -64,7 +64,7 @@ class DFTU_RELAX : public DFTU_Yukawa //locale_save: the input local occupation number matrix of correlated electrons in the current electronic step std::vector>>> locale; // locale[iat][l][n][spin](m1,m2) std::vector>>> locale_save; // locale_save[iat][l][n][spin](m1,m2) -private: +protected: LCAO_Matrix* LM; }; } diff --git a/source/src_lcao/record_adj.cpp b/source/src_lcao/record_adj.cpp index 966383793d9..ff6aecb1765 100644 --- a/source/src_lcao/record_adj.cpp +++ b/source/src_lcao/record_adj.cpp @@ -26,14 +26,26 @@ void Record_adj::delete_grid(void) //-------------------------------------------- // This will record the orbitals according to // HPSEPS's 2D block division. +// If multi-k, calculate nnr at the same time. +// be called only once in an ion-step. //-------------------------------------------- -void Record_adj::for_2d(const Parallel_Orbitals &pv) +void Record_adj::for_2d(Parallel_Orbitals &pv, bool gamma_only) { ModuleBase::TITLE("Record_adj","for_2d"); ModuleBase::timer::tick("Record_adj","for_2d"); - assert(GlobalC::ucell.nat>0); - + assert(GlobalC::ucell.nat > 0); + if (!gamma_only) + { + delete[] pv.nlocdim; + delete[] pv.nlocstart; + pv.nlocdim = new int[GlobalC::ucell.nat]; + pv.nlocstart = new int[GlobalC::ucell.nat]; + ModuleBase::GlobalFunc::ZEROS(pv.nlocdim, GlobalC::ucell.nat); + ModuleBase::GlobalFunc::ZEROS(pv.nlocstart, GlobalC::ucell.nat); + pv.nnr = 0; + } + // (1) find the adjacent atoms of atom[T1,I1]; ModuleBase::Vector3 tau1, tau2, dtau; ModuleBase::Vector3 dtau1, dtau2, tau0; @@ -44,8 +56,6 @@ void Record_adj::for_2d(const Parallel_Orbitals &pv) this->na_each = new int[na_proc]; ModuleBase::GlobalFunc::ZEROS(na_each, na_proc); int iat = 0; - int irr = 0; - // std::cout << " in for_2d" << std::endl; @@ -58,8 +68,9 @@ void Record_adj::for_2d(const Parallel_Orbitals &pv) //GlobalC::GridD.Find_atom( tau1 ); GlobalC::GridD.Find_atom(GlobalC::ucell, tau1 ,T1, I1); const int start1 = GlobalC::ucell.itiaiw2iwt(T1, I1, 0); - - // (2) search among all adjacent atoms. + if(!gamma_only) pv.nlocstart[iat] = pv.nnr; + + // (2) search among all adjacent atoms. for (int ad = 0; ad < GlobalC::GridD.getAdjacentNum()+1; ++ad) { const int T2 = GlobalC::GridD.getType(ad); @@ -71,8 +82,12 @@ void Record_adj::for_2d(const Parallel_Orbitals &pv) double rcut = GlobalC::ORB.Phi[T1].getRcut() + GlobalC::ORB.Phi[T2].getRcut(); bool is_adj = false; - if( distance < rcut) is_adj = true; - else if( distance >= rcut) + if (distance < rcut) is_adj = true; + // there is another possibility that i and j are adjacent atoms. + // which is that are adjacents while are also + // adjacents, these considerations are only considered in k-point + // algorithm, + else if (distance >= rcut) { for (int ad0 = 0; ad0 < GlobalC::GridD.getAdjacentNum()+1; ++ad0) { @@ -98,39 +113,37 @@ void Record_adj::for_2d(const Parallel_Orbitals &pv) } } - - if(is_adj) { - ++na_each[iat]; - - for(int ii=0; iinw * GlobalV::NPOL; ++ii) - { - // the index of orbitals in this processor - const int iw1_all = start1 + ii; - const int mu = pv.trace_loc_row[iw1_all]; - if(mu<0)continue; + ++na_each[iat]; + if (!gamma_only) + { + for(int ii=0; iinw * GlobalV::NPOL; ++ii) + { + // the index of orbitals in this processor + const int iw1_all = start1 + ii; + const int mu = pv.trace_loc_row[iw1_all]; + if(mu<0)continue; - for(int jj=0; jj 10|| I2<0) exit(-1); + ++cb; } }//end ad // GlobalV::ofs_running << " nadj = " << cb << std::endl; diff --git a/source/src_lcao/record_adj.h b/source/src_lcao/record_adj.h index 6290f7ada07..05eb8c03722 100644 --- a/source/src_lcao/record_adj.h +++ b/source/src_lcao/record_adj.h @@ -2,7 +2,7 @@ #define RECORD_ADJ_H #include "grid_technique.h" -#include "src_parallel/parallel_orbitals.h" +#include "module_orbital/parallel_orbitals.h" //--------------------------------------------------- @@ -19,7 +19,7 @@ class Record_adj // This will record the orbitals according to // HPSEPS's 2D block division. //-------------------------------------------- - void for_2d(const Parallel_Orbitals &pv); + void for_2d(Parallel_Orbitals &pv, bool gamma_only); //-------------------------------------------- // This will record the orbitals according to diff --git a/source/src_lcao/run_md_lcao.cpp b/source/src_lcao/run_md_lcao.cpp index 18e83353aa7..10575077e24 100644 --- a/source/src_lcao/run_md_lcao.cpp +++ b/source/src_lcao/run_md_lcao.cpp @@ -5,7 +5,7 @@ #include "../src_pw/vdwd2.h" #include "../src_pw/vdwd2_parameters.h" #include "../src_pw/vdwd3_parameters.h" -#include "../src_parallel/parallel_orbitals.h" +#include "../module_orbital/parallel_orbitals.h" #include "../src_pdiag/pdiag_double.h" #include "../src_io/write_HS.h" #include "../src_io/cal_r_overlap_R.h" @@ -34,7 +34,7 @@ Run_MD_LCAO::Run_MD_LCAO(Parallel_Orbitals &pv) Run_MD_LCAO::~Run_MD_LCAO(){} -void Run_MD_LCAO::opt_cell(ORB_control &orb_con) +void Run_MD_LCAO::opt_cell(ORB_control &orb_con, ModuleEnSover::En_Solver *p_ensolver) { ModuleBase::TITLE("Run_MD_LCAO","opt_cell"); @@ -58,16 +58,14 @@ void Run_MD_LCAO::opt_cell(ORB_control &orb_con) int ion_step=0; GlobalC::pot.init_pot(ion_step, GlobalC::pw.strucFac); - - opt_ions(); - + opt_ions(p_ensolver); orb_con.clear_after_ions(GlobalC::UOT, GlobalC::ORB, GlobalV::out_descriptor, GlobalC::ucell.infoNL.nproj); return; } -void Run_MD_LCAO::opt_ions(void) +void Run_MD_LCAO::opt_ions(ModuleEnSover::En_Solver *p_ensolver) { ModuleBase::TITLE("Run_MD_LCAO","opt_ions"); ModuleBase::timer::tick("Run_MD_LCAO","opt_ions"); @@ -121,7 +119,7 @@ void Run_MD_LCAO::opt_ions(void) if(verlet->step_ == 0) { MD_func::ParaV = this->LM_md.ParaV; - verlet->setup(); + verlet->setup(p_ensolver); } else { @@ -150,7 +148,7 @@ void Run_MD_LCAO::opt_ions(void) GlobalC::pot.init_pot(verlet->step_, GlobalC::pw.strucFac); // update force and virial due to the update of atom positions - MD_func::force_virial(verlet->step_, verlet->mdp, verlet->ucell, verlet->potential, verlet->force, verlet->virial); + MD_func::force_virial(p_ensolver, verlet->step_, verlet->mdp, verlet->ucell, verlet->potential, verlet->force, verlet->virial); verlet->second_half(); @@ -204,6 +202,7 @@ void Run_MD_LCAO::opt_ions(void) } void Run_MD_LCAO::md_force_virial( + ModuleEnSover::En_Solver *p_ensolver, const int &istep, const int& numIon, double &potential, @@ -255,16 +254,19 @@ void Run_MD_LCAO::md_force_virial( // mohan add 2021-02-09 LCAO_Hamilt UHM_md; UHM_md.genH.LM = UHM_md.LM = &this->LM_md; + + Record_adj RA_md; + LOOP_elec LOE; - LOE.solve_elec_stru(istep + 1, LOC_md, LOWF_md, UHM_md); + LOE.solve_elec_stru(istep + 1, RA_md, LOC_md, LOWF_md, UHM_md); //to call the force of each atom ModuleBase::matrix fcs;//temp force matrix - Force_Stress_LCAO FSL; + Force_Stress_LCAO FSL(RA_md); FSL.getForceStress(GlobalV::FORCE, GlobalV::STRESS, GlobalV::TEST_FORCE, GlobalV::TEST_STRESS, LOC_md, LOWF_md, UHM_md, fcs, virial); - + RA_md.delete_grid(); for(int ion=0; ion* force, diff --git a/source/src_parallel/CMakeLists.txt b/source/src_parallel/CMakeLists.txt index 0eb4a88689f..a3b9122edd0 100644 --- a/source/src_parallel/CMakeLists.txt +++ b/source/src_parallel/CMakeLists.txt @@ -4,7 +4,6 @@ list(APPEND objects parallel_global.cpp parallel_grid.cpp parallel_kpoints.cpp - parallel_orbitals.cpp parallel_pw.cpp parallel_reduce.cpp subgrid_oper.cpp diff --git a/source/src_pdiag/CMakeLists.txt b/source/src_pdiag/CMakeLists.txt index 18244f7adbf..d1844471936 100644 --- a/source/src_pdiag/CMakeLists.txt +++ b/source/src_pdiag/CMakeLists.txt @@ -4,7 +4,6 @@ add_library( pdiag OBJECT pdgseps.cpp - pdiag_basic.cpp pdiag_common.cpp pdiag_double.cpp pdst2g.cpp diff --git a/source/src_pdiag/pdiag_basic.cpp b/source/src_pdiag/pdiag_basic.cpp deleted file mode 100644 index dc5ec2007b2..00000000000 --- a/source/src_pdiag/pdiag_basic.cpp +++ /dev/null @@ -1,1130 +0,0 @@ -#include "../src_parallel/parallel_common.h" -#include "src_parallel/parallel_orbitals.h" -#include "src_pdiag/pdiag_double.h" -#include "../src_pw/global.h" -#include "../src_io/wf_local.h" -#include "../module_base/lapack_connector.h" -#include "../module_base/memory.h" - - -void ORB_control::set_parameters(void) -{ - ModuleBase::TITLE("ORB_control","set_parameters"); - - Parallel_Orbitals* pv = &this->ParaV; - // set loc_size - if(GlobalV::GAMMA_ONLY_LOCAL)//xiaohui add 2014-12-21 - { - pv->loc_size=GlobalV::NBANDS/GlobalV::DSIZE; - - // mohan add 2012-03-29 - if(pv->loc_size==0) - { - GlobalV::ofs_warning << " loc_size=0" << " in proc " << GlobalV::MY_RANK+1 << std::endl; - ModuleBase::WARNING_QUIT("ORB_control::set_parameters","NLOCAL < GlobalV::DSIZE"); - } - - if (GlobalV::DRANKloc_size+=1; - if(pv->testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"local size",pv->loc_size); - - // set loc_sizes - delete[] pv->loc_sizes; - pv->loc_sizes = new int[GlobalV::DSIZE]; - ModuleBase::GlobalFunc::ZEROS(pv->loc_sizes, GlobalV::DSIZE); - - pv->lastband_in_proc = 0; - pv->lastband_number = 0; - int count_bands = 0; - for (int i=0; iloc_sizes[i]=GlobalV::NBANDS/GlobalV::DSIZE+1; - } - else - { - pv->loc_sizes[i]=GlobalV::NBANDS/GlobalV::DSIZE; - } - count_bands += pv->loc_sizes[i]; - if (count_bands >= GlobalV::NBANDS) - { - pv->lastband_in_proc = i; - pv->lastband_number = GlobalV::NBANDS - (count_bands - pv->loc_sizes[i]); - break; - } - } - } - else - { - pv->loc_size=GlobalV::NLOCAL/GlobalV::DSIZE; - - // mohan add 2012-03-29 - if(pv->loc_size==0) - { - GlobalV::ofs_warning << " loc_size=0" << " in proc " << GlobalV::MY_RANK+1 << std::endl; - ModuleBase::WARNING_QUIT("ORB_control::set_parameters","NLOCAL < GlobalV::DSIZE"); - } - - if (GlobalV::DRANKloc_size += 1; - } - if(pv->testpb) ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"local size",pv->loc_size); - - // set loc_sizes - delete[] pv->loc_sizes; - pv->loc_sizes = new int[GlobalV::DSIZE]; - ModuleBase::GlobalFunc::ZEROS(pv->loc_sizes, GlobalV::DSIZE); - - pv->lastband_in_proc = 0; - pv->lastband_number = 0; - int count_bands = 0; - for (int i=0; iloc_sizes[i]=GlobalV::NLOCAL/GlobalV::DSIZE+1; - } - else - { - pv->loc_sizes[i]=GlobalV::NLOCAL/GlobalV::DSIZE; - } - count_bands += pv->loc_sizes[i]; - if (count_bands >= GlobalV::NBANDS) - { - pv->lastband_in_proc = i; - pv->lastband_number = GlobalV::NBANDS - (count_bands - pv->loc_sizes[i]); - break; - } - } - }//xiaohui add 2014-12-21 - - if (GlobalV::KS_SOLVER=="hpseps") //LiuXh add 2021-09-06, clear memory, Z_LOC only used in hpseps solver - { - pv->Z_LOC = new double*[GlobalV::NSPIN]; - for(int is=0; isZ_LOC[is] = new double[pv->loc_size * GlobalV::NLOCAL]; - ModuleBase::GlobalFunc::ZEROS(pv->Z_LOC[is], pv->loc_size * GlobalV::NLOCAL); - } - pv->alloc_Z_LOC = true;//xiaohui add 2014-12-22 - } - - if(pv->testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"lastband_in_proc", pv->lastband_in_proc); - if(pv->testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"lastband_number", pv->lastband_number); - - return; -} - - -#ifdef __MPI -// creat the 'comm_2D' stratege. -void ORB_control::mpi_creat_cart(MPI_Comm *comm_2D, int prow, int pcol) -{ - ModuleBase::TITLE("ORB_control","mpi_creat_cart"); - // the matrix is divided as ( dim[0] * dim[1] ) - int dim[2]; - int period[2]={1,1}; - int reorder=0; - dim[0]=prow; - dim[1]=pcol; - - if(this->ParaV.testpb)GlobalV::ofs_running << " dim = " << dim[0] << " * " << dim[1] << std::endl; - - MPI_Cart_create(DIAG_WORLD,2,dim,period,reorder,comm_2D); - return; -} -#endif - -#ifdef __MPI -void ORB_control::mat_2d(MPI_Comm vu, - const int &M_A, - const int &N_A, - const int &nb, - LocalMatrix &LM) -{ - ModuleBase::TITLE("ORB_control", "mat_2d"); - - Parallel_Orbitals* pv = &this->ParaV; - - int dim[2]; - int period[2]; - int coord[2]; - int i,j,k,end_id; - int block; - - // (0) every processor get it's id on the 2D comm - // : ( coord[0], coord[1] ) - MPI_Cart_get(vu,2,dim,period,coord); - - // (1.1) how many blocks at least - // eg. M_A = 6400, nb = 64; - // so block = 10; - block=M_A/nb; - - // (1.2) If data remain, add 1. - if (block*nbtestpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"Total Row Blocks Number",block); - - // mohan add 2010-09-12 - if(dim[0]>block) - { - GlobalV::ofs_warning << " cpu 2D distribution : " << dim[0] << "*" << dim[1] << std::endl; - GlobalV::ofs_warning << " but, the number of row blocks is " << block << std::endl; - ModuleBase::WARNING_QUIT("ORB_control::mat_2d","some processor has no row blocks, try a smaller 'nb2d' parameter."); - } - - // (2.1) row_b : how many blocks for this processor. (at least) - LM.row_b=block/dim[0]; - - // (2.2) row_b : how many blocks in this processor. - // if there are blocks remain, some processors add 1. - if (coord[0]testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"Local Row Block Number",LM.row_b); - - // (3) end_id indicates the last block belong to - // which processor. - if (block%dim[0]==0) - { - end_id=dim[0]-1; - } - else - { - end_id=block%dim[0]-1; - } - - if(pv->testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"Ending Row Block in processor",end_id); - - // (4) row_num : how many rows in this processors : - // the one owns the last block is different. - if (coord[0]==end_id) - { - LM.row_num=(LM.row_b-1)*nb+(M_A-(block-1)*nb); - } - else - { - LM.row_num=LM.row_b*nb; - } - - if(pv->testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"Local rows (including nb)",LM.row_num); - - // (5) row_set, it's a global index : - // save explicitly : every row in this processor - // belongs to which row in the global matrix. - delete[] LM.row_set; - LM.row_set= new int[LM.row_num]; - j=0; - for (i=0; iblock) - { - GlobalV::ofs_warning << " cpu 2D distribution : " << dim[0] << "*" << dim[1] << std::endl; - GlobalV::ofs_warning << " but, the number of column blocks is " << block << std::endl; - ModuleBase::WARNING_QUIT("ORB_control::mat_2d","some processor has no column blocks."); - } - - LM.col_b=block/dim[1]; - if (coord[1]testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"Local Row Block Number",LM.col_b); - - if (block%dim[1]==0) - { - end_id=dim[1]-1; - } - else - { - end_id=block%dim[1]-1; - } - - if(pv->testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"Ending Row Block in processor",end_id); - - if (coord[1]==end_id) - { - LM.col_num=(LM.col_b-1)*nb+(M_A-(block-1)*nb); - } - else - { - LM.col_num=LM.col_b*nb; - } - - if(pv->testpb)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"Local columns (including nb)",LM.row_num); - - delete[] LM.col_set; - LM.col_set = new int[LM.col_num]; - - j=0; - for (i=0; iblock) - { - GlobalV::ofs_warning << " cpu 2D distribution : " << dim[0] << "*" << dim[1] << std::endl; - GlobalV::ofs_warning << " but, the number of bands-row-block is " << block << std::endl; - ModuleBase::WARNING_QUIT("ORB_control::mat_2d","some processor has no bands-row-blocks."); - } - int col_b_bands = block / dim[1]; - if (coord[1] < block % dim[1]) - { - col_b_bands++; - } - if (block%dim[1]==0) - { - end_id=dim[1]-1; - } - else - { - end_id=block%dim[1]-1; - } - if (coord[1]==end_id) - { - pv->ncol_bands=(col_b_bands-1)*nb+(N_A-(block-1)*nb); - } - else - { - pv->ncol_bands=col_b_bands*nb; - } - pv->nloc_wfc = pv->ncol_bands * LM.row_num; - - std::cout << pv->nloc_wfc << " " << pv->ncol_bands << " " << LM.row_num << std::endl; - - return; -} -#endif - - -#ifdef __MPI -// A : contains total matrix element in processor. -void ORB_control::data_distribution( - MPI_Comm comm_2D, - const std::string &file, - const int &n, - const int &nb, - double *A, - const LocalMatrix &LM) -{ - ModuleBase::TITLE("ORB_control", "data_distribution"); - Parallel_Orbitals* pv = &this->ParaV; - MPI_Comm comm_row; - MPI_Comm comm_col; - MPI_Status status; - - int dim[2]; - int period[2]; - int coord[2]; - MPI_Cart_get(comm_2D,2,dim,period,coord); - - if(pv->testpb) GlobalV::ofs_running << "\n dim = " << dim[0] << " * " << dim[1] << std::endl; - if(pv->testpb) GlobalV::ofs_running << " coord = ( " << coord[0] << " , " << coord[1] << ")." << std::endl; - if(pv->testpb) GlobalV::ofs_running << " n = " << n << std::endl; - - mpi_sub_col(comm_2D,&comm_col); - mpi_sub_row(comm_2D,&comm_row); - - // total number of processors - const int myid = coord[0]*dim[1]+coord[1]; - - // the matrix is n * n - double* ele_val = new double[n]; - double* val = new double[n]; - int* sends = new int[dim[1]]; - int* fpt = new int[dim[1]]; - int* snd = new int[dim[1]]; - int* temp = new int[dim[1]]; - - ModuleBase::GlobalFunc::ZEROS(ele_val, n); - ModuleBase::GlobalFunc::ZEROS(val, n); - ModuleBase::GlobalFunc::ZEROS(sends, dim[1]); - ModuleBase::GlobalFunc::ZEROS(fpt, dim[1]); - ModuleBase::GlobalFunc::ZEROS(snd, dim[1]); - ModuleBase::GlobalFunc::ZEROS(temp, dim[1]); - - // the columes of matrix is divided by 'dim[1]' 'rows of processors'. - // collect all information of each 'rows of processors' - // collection data is saved in 'sends' - snd[coord[1]] = LM.col_num; - MPI_Allgather(&snd[coord[1]],1,MPI_INT,sends,1,MPI_INT,comm_row); - - // fpt : start column index after applied 'mat_2d' reorder algorithms - // to matrix. - fpt[0] = 0; - for (int i=1; i **cc,std::complex *Z, const int &ik) -{ - ModuleBase::TITLE("Pdiag_Double","gath_eig_complex"); - time_t time_start = time(NULL); - //GlobalV::ofs_running << " Start gath_eig_complex Time : " << ctime(&time_start); - const Parallel_Orbitals* pv = this->ParaV; - - int i,j,k; - int nprocs,myid; - MPI_Status status; - MPI_Comm_size(comm,&nprocs); - MPI_Comm_rank(comm,&myid); - - std::complex **ctot; - - // mohan add 2010-07-03 - // the occupied bands are useless - // for calculating charge density. - if(GlobalV::DRANK> pv->lastband_in_proc) - { - delete[] Z; - } - - // first we need to collect all - // the occupied bands. - // GlobalV::NBANDS * GlobalV::NLOCAL - if(GlobalV::DRANK==0) - { - ctot = new std::complex*[GlobalV::NBANDS]; - for (int i=0; i[GlobalV::NLOCAL]; - ModuleBase::GlobalFunc::ZEROS(ctot[i], GlobalV::NLOCAL); - } - ModuleBase::Memory::record("Pdiag_Double","ctot",GlobalV::NBANDS*GlobalV::NLOCAL,"cdouble"); - } - - k=0; - if (myid==0) - { - // mohan add nbnd0 2010-07-02 - int nbnd0 = -1; - if (GlobalV::NBANDS < pv->loc_sizes[0]) - { - // means all bands in this processor - // is needed ( is occupied) - nbnd0 = GlobalV::NBANDS; - } - else - { - // means this processor only save - // part of GlobalV::NBANDS. - nbnd0 = pv->loc_sizes[0]; - } - if(pv->testpb)GlobalV::ofs_running << " nbnd in processor 0 is " << nbnd0 << std::endl; - - for (i=0; iloc_sizes[0]+i]; - } - k++; - } - // Z is useless in processor 0 now. - delete[] Z; - } - MPI_Barrier(comm); - - for (i=1; i<= pv->lastband_in_proc; i++) - { - // mohan fix bug 2010-07-02 - // rows indicates the data structure of Z. - // mpi_times indicates the data distribution - // time, each time send a band. - int rows = pv->loc_sizes[i]; - int mpi_times; - if (i==pv->lastband_in_proc) - { - mpi_times = pv->lastband_number; - } - else - { - mpi_times = pv->loc_sizes[i]; - } - if(pv->testpb)GlobalV::ofs_running << " nbnd in processor " << i << " is " << mpi_times << std::endl; - if (myid==i) - { - for (j=0; j *send = new std::complex[n]; - int count = 0; - - for (int m=0; m *ctmp = new std::complex[GlobalV::NLOCAL]; - ModuleBase::GlobalFunc::ZEROS(ctmp, GlobalV::NLOCAL); - int tag = j; - - // Processor 0 receive the data from other processors. - MPI_Recv(ctmp,n,mpicomplex,i,tag,comm,&status); - - for (int m=0; mout_lowf) - { -// std::cout << " write the wave functions" << std::endl; - WF_Local::write_lowf_complex( ss.str(), ctot, ik );//mohan add 2010-09-09 - } - - // mohan add 2010-09-10 - // distribution of local wave functions - // to each processor. - WF_Local::distri_lowf_complex( ctot, cc); - - // clean staff. - if(GlobalV::DRANK==0) - { - for (int i=0; i **c,std::complex *Z) -{ - ModuleBase::TITLE("Pdiag_Double","gath_full_eig_complex"); - - time_t time_start = time(NULL); - //GlobalV::ofs_running << " Start gath_full_eig_complex Time : " << ctime(&time_start); - - int i,j,k,incx=1; - int *loc_sizes,loc_size,nprocs,myid; - MPI_Status status; - MPI_Comm_size(comm,&nprocs); - MPI_Comm_rank(comm,&myid); - loc_sizes =(int*)malloc(sizeof(int)*nprocs); - loc_size=n/nprocs; - for (i=0; iParaV; - - int i, j, k; - int nprocs,myid; - MPI_Status status; - MPI_Comm_size(comm,&nprocs); - MPI_Comm_rank(comm,&myid); - - double **ctot; - - // mohan add 2010-07-03 - // the occupied bands are useless - // for calculating charge density. - if(GlobalV::DRANK > pv->lastband_in_proc) - { - delete[] Z; - } - - // first we need to collect all - // the occupied bands. - // GlobalV::NBANDS * GlobalV::NLOCAL - if(myid==0) - { - ctot = new double*[GlobalV::NBANDS]; - for (int i=0; iloc_sizes[0]) - { - // means all bands in this processor - // is needed ( is occupied) - nbnd0 = GlobalV::NBANDS; - } - else - { - // means this processor only save - // part of GlobalV::NBANDS. - nbnd0 = pv->loc_sizes[0]; - } - if(pv->testpb) GlobalV::ofs_running << " nbnd in processor 0 is " << nbnd0 << std::endl; - -//printf("from 0 to %d\n",nbnd0-1); - for (i=0; iloc_sizes[0]+i]; - } - k++; - } - // Z is useless in processor 0 now. - delete[] Z; - } - MPI_Barrier(comm); - - - for (i=1; i<= pv->lastband_in_proc; i++) - { - // mohan fix bug 2010-07-02 - // rows indicates the data structure of Z. - // mpi_times indicates the data distribution - // time, each time send a band. - int rows = pv->loc_sizes[i]; - int mpi_times; - if (i==pv->lastband_in_proc) - { - mpi_times = pv->lastband_number; - } - else - { - mpi_times = pv->loc_sizes[i]; - } - if(pv->testpb)GlobalV::ofs_running << " nbnd in processor " << i << " is " << mpi_times << std::endl; - if (myid==i) - { - for (j=0; j1); -} -*/ -MPI_Barrier(comm); - - // mohan add 2010-09-10 - // output the wave function if required. - // this is a bad position to output wave functions. - // but it works! - if(this->out_lowf) - { - // read is in ../src_algorithms/wf_local.cpp - std::stringstream ss; - ss << GlobalV::global_out_dir << "LOWF_GAMMA_S" << GlobalV::CURRENT_SPIN+1 << ".dat"; - // mohan add 2012-04-03, because we need the occupations for the - // first iteration. - Occupy::calculate_weights(); - WF_Local::write_lowf( ss.str(), ctot );//mohan add 2010-09-09 - } - - // mohan add 2010-09-10 - // distribution of local wave functions - // to each processor. - // only used for GlobalV::GAMMA_ONLY_LOCAL - //WF_Local::distri_lowf( ctot, wfc); - - - // clean staff. - if(myid==0) - { - for (int i=0; i **cc,std::complex *Z, const int &ik) +{ + ModuleBase::TITLE("Pdiag_Double","gath_eig_complex"); + time_t time_start = time(NULL); + //GlobalV::ofs_running << " Start gath_eig_complex Time : " << ctime(&time_start); + const Parallel_Orbitals* pv = this->ParaV; + + int i,j,k; + int nprocs,myid; + MPI_Status status; + MPI_Comm_size(comm,&nprocs); + MPI_Comm_rank(comm,&myid); + + std::complex **ctot; + + // mohan add 2010-07-03 + // the occupied bands are useless + // for calculating charge density. + if(GlobalV::DRANK> pv->lastband_in_proc) + { + delete[] Z; + } + + // first we need to collect all + // the occupied bands. + // GlobalV::NBANDS * GlobalV::NLOCAL + if(GlobalV::DRANK==0) + { + ctot = new std::complex*[GlobalV::NBANDS]; + for (int i=0; i[GlobalV::NLOCAL]; + ModuleBase::GlobalFunc::ZEROS(ctot[i], GlobalV::NLOCAL); + } + ModuleBase::Memory::record("Pdiag_Double","ctot",GlobalV::NBANDS*GlobalV::NLOCAL,"cdouble"); + } + + k=0; + if (myid==0) + { + // mohan add nbnd0 2010-07-02 + int nbnd0 = -1; + if (GlobalV::NBANDS < pv->loc_sizes[0]) + { + // means all bands in this processor + // is needed ( is occupied) + nbnd0 = GlobalV::NBANDS; + } + else + { + // means this processor only save + // part of GlobalV::NBANDS. + nbnd0 = pv->loc_sizes[0]; + } + if(pv->testpb)GlobalV::ofs_running << " nbnd in processor 0 is " << nbnd0 << std::endl; + + for (i=0; iloc_sizes[0]+i]; + } + k++; + } + // Z is useless in processor 0 now. + delete[] Z; + } + MPI_Barrier(comm); + + for (i=1; i<= pv->lastband_in_proc; i++) + { + // mohan fix bug 2010-07-02 + // rows indicates the data structure of Z. + // mpi_times indicates the data distribution + // time, each time send a band. + int rows = pv->loc_sizes[i]; + int mpi_times; + if (i==pv->lastband_in_proc) + { + mpi_times = pv->lastband_number; + } + else + { + mpi_times = pv->loc_sizes[i]; + } + if(pv->testpb)GlobalV::ofs_running << " nbnd in processor " << i << " is " << mpi_times << std::endl; + if (myid==i) + { + for (j=0; j *send = new std::complex[n]; + int count = 0; + + for (int m=0; m *ctmp = new std::complex[GlobalV::NLOCAL]; + ModuleBase::GlobalFunc::ZEROS(ctmp, GlobalV::NLOCAL); + int tag = j; + + // Processor 0 receive the data from other processors. + MPI_Recv(ctmp,n,mpicomplex,i,tag,comm,&status); + + for (int m=0; mout_lowf) + { +// std::cout << " write the wave functions" << std::endl; + WF_Local::write_lowf_complex( ss.str(), ctot, ik );//mohan add 2010-09-09 + } + + // mohan add 2010-09-10 + // distribution of local wave functions + // to each processor. + WF_Local::distri_lowf_complex( ctot, cc); + + // clean staff. + if(GlobalV::DRANK==0) + { + for (int i=0; i **c,std::complex *Z) +{ + ModuleBase::TITLE("Pdiag_Double","gath_full_eig_complex"); + + time_t time_start = time(NULL); + //GlobalV::ofs_running << " Start gath_full_eig_complex Time : " << ctime(&time_start); + + int i,j,k,incx=1; + int *loc_sizes,loc_size,nprocs,myid; + MPI_Status status; + MPI_Comm_size(comm,&nprocs); + MPI_Comm_rank(comm,&myid); + loc_sizes =(int*)malloc(sizeof(int)*nprocs); + loc_size=n/nprocs; + for (i=0; iParaV; + + int i, j, k; + int nprocs,myid; + MPI_Status status; + MPI_Comm_size(comm,&nprocs); + MPI_Comm_rank(comm,&myid); + + double **ctot; + + // mohan add 2010-07-03 + // the occupied bands are useless + // for calculating charge density. + if(GlobalV::DRANK > pv->lastband_in_proc) + { + delete[] Z; + } + + // first we need to collect all + // the occupied bands. + // GlobalV::NBANDS * GlobalV::NLOCAL + if(myid==0) + { + ctot = new double*[GlobalV::NBANDS]; + for (int i=0; iloc_sizes[0]) + { + // means all bands in this processor + // is needed ( is occupied) + nbnd0 = GlobalV::NBANDS; + } + else + { + // means this processor only save + // part of GlobalV::NBANDS. + nbnd0 = pv->loc_sizes[0]; + } + if(pv->testpb) GlobalV::ofs_running << " nbnd in processor 0 is " << nbnd0 << std::endl; + +//printf("from 0 to %d\n",nbnd0-1); + for (i=0; iloc_sizes[0]+i]; + } + k++; + } + // Z is useless in processor 0 now. + delete[] Z; + } + MPI_Barrier(comm); + + + for (i=1; i<= pv->lastband_in_proc; i++) + { + // mohan fix bug 2010-07-02 + // rows indicates the data structure of Z. + // mpi_times indicates the data distribution + // time, each time send a band. + int rows = pv->loc_sizes[i]; + int mpi_times; + if (i==pv->lastband_in_proc) + { + mpi_times = pv->lastband_number; + } + else + { + mpi_times = pv->loc_sizes[i]; + } + if(pv->testpb)GlobalV::ofs_running << " nbnd in processor " << i << " is " << mpi_times << std::endl; + if (myid==i) + { + for (j=0; j1); +} +*/ +MPI_Barrier(comm); + + // mohan add 2010-09-10 + // output the wave function if required. + // this is a bad position to output wave functions. + // but it works! + if(this->out_lowf) + { + // read is in ../src_algorithms/wf_local.cpp + std::stringstream ss; + ss << GlobalV::global_out_dir << "LOWF_GAMMA_S" << GlobalV::CURRENT_SPIN+1 << ".dat"; + // mohan add 2012-04-03, because we need the occupations for the + // first iteration. + Occupy::calculate_weights(); + WF_Local::write_lowf( ss.str(), ctot );//mohan add 2010-09-09 + } + + // mohan add 2010-09-10 + // distribution of local wave functions + // to each processor. + // only used for GlobalV::GAMMA_ONLY_LOCAL + //WF_Local::distri_lowf( ctot, wfc); + + + // clean staff. + if(myid==0) + { + for (int i=0; i *b1, complex *b2, complex ci_tpi = ModuleBase::NEG_IMAG_UNIT * ModuleBase::TWO_PI; ModuleBase::Bspline bsp; bsp.init(norder, 1, 0); - bsp.getbslpine(1.0); + bsp.getbspline(1.0); for(int ix = 0 ; ix < this->nx ; ++ix) { complex fracx=0; @@ -141,4 +141,4 @@ void PW_Basis:: bsplinecoef(complex *b1, complex *b2, complexstep_ == 0) { - verlet->setup(); + verlet->setup(p_ensolver); } else { @@ -113,7 +105,7 @@ void Run_MD_PW::md_ions_pw(void) GlobalC::wf.wfcinit(); // update force and virial due to the update of atom positions - MD_func::force_virial(verlet->step_, verlet->mdp, verlet->ucell, verlet->potential, verlet->force, verlet->virial); + MD_func::force_virial(p_ensolver, verlet->step_, verlet->mdp, verlet->ucell, verlet->potential, verlet->force, verlet->virial); verlet->second_half(); @@ -159,83 +151,14 @@ void Run_MD_PW::md_ions_pw(void) return; } -void Run_MD_PW::md_cells_pw() +void Run_MD_PW::md_cells_pw(ModuleEnSover::En_Solver *p_ensolver) { ModuleBase::TITLE("Run_MD_PW", "md_cells_pw"); ModuleBase::timer::tick("Run_MD_PW", "md_cells_pw"); - GlobalC::wf.allocate(GlobalC::kv.nks); - - GlobalC::UFFT.allocate(); - - //======================= - // init pseudopotential - //======================= - GlobalC::ppcell.init(GlobalC::ucell.ntype); - - //===================== - // init hamiltonian - // only allocate in the beginning of ELEC LOOP! - //===================== - GlobalC::hm.hpw.allocate(GlobalC::wf.npwx, GlobalV::NPOL, GlobalC::ppcell.nkb, GlobalC::pw.nrxx); - - //================================= - // initalize local pseudopotential - //================================= - GlobalC::ppcell.init_vloc(GlobalC::pw.nggm, GlobalC::ppcell.vloc); - ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "LOCAL POTENTIAL"); - - //====================================== - // Initalize non local pseudopotential - //====================================== - GlobalC::ppcell.init_vnl(GlobalC::ucell); - ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "NON-LOCAL POTENTIAL"); - - //========================================================= - // calculate the total local pseudopotential in real space - //========================================================= - GlobalC::pot.init_pot(0, GlobalC::pw.strucFac); //atomic_rho, v_of_rho, set_vrs - - GlobalC::pot.newd(); - - ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "INIT POTENTIAL"); - - //================================================== - // create GlobalC::ppcell.tab_at , for trial wave functions. - //================================================== - GlobalC::wf.init_at_1(); - - //================================ - // Initial start wave functions - //================================ - if (GlobalV::NBANDS != 0 ) // liuyu update 2021-12-10 - { - GlobalC::wf.wfcinit(); - } -#ifdef __LCAO -#ifdef __MPI - switch (GlobalC::exx_global.info.hybrid_type) // Peize Lin add 2019-03-09 - { - case Exx_Global::Hybrid_Type::HF: - case Exx_Global::Hybrid_Type::PBE0: - case Exx_Global::Hybrid_Type::HSE: - GlobalC::exx_lip.init(&GlobalC::kv, &GlobalC::wf, &GlobalC::pw, &GlobalC::UFFT, &GlobalC::ucell); - break; - case Exx_Global::Hybrid_Type::No: - break; - case Exx_Global::Hybrid_Type::Generate_Matrix: - default: - throw std::invalid_argument(ModuleBase::GlobalFunc::TO_STRING(__FILE__) + ModuleBase::GlobalFunc::TO_STRING(__LINE__)); - } -#endif -#endif - - ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "INIT BASIS"); - // ion optimization begins // electron density optimization is included in ion optimization - - this->md_ions_pw(); + this->md_ions_pw(p_ensolver); GlobalV::ofs_running << "\n\n --------------------------------------------" << std::endl; GlobalV::ofs_running << std::setprecision(16); @@ -246,6 +169,7 @@ void Run_MD_PW::md_cells_pw() } void Run_MD_PW::md_force_virial( + ModuleEnSover::En_Solver *p_ensolver, const int &istep, const int& numIon, double &potential, @@ -333,8 +257,9 @@ void Run_MD_PW::md_force_virial( } ModuleBase::matrix fcs; - Forces ff; - ff.init(fcs); + // Forces ff; + // ff.init(fcs); + p_ensolver->cal_Force(fcs); for(int ion=0;ioncal_Stress(virial); virial = 0.5 * virial; } diff --git a/source/src_pw/run_md_pw.h b/source/src_pw/run_md_pw.h index 80f3d5ccbab..693f25e217e 100644 --- a/source/src_pw/run_md_pw.h +++ b/source/src_pw/run_md_pw.h @@ -3,6 +3,7 @@ #define RUN_MD_PW_H #include "../src_pw/charge_extra.h" +#include "module_ensolver/en_solver.h" class Run_MD_PW { @@ -10,9 +11,11 @@ class Run_MD_PW Run_MD_PW(); ~Run_MD_PW(); - void md_ions_pw(); - void md_cells_pw(); - void md_force_virial(const int &istep, + void md_ions_pw(ModuleEnSover::En_Solver *p_ensolver); + void md_cells_pw(ModuleEnSover::En_Solver *p_ensolver); + void md_force_virial( + ModuleEnSover::En_Solver *p_ensolver, + const int &istep, const int& numIon, double &potential, ModuleBase::Vector3* force, diff --git a/source/src_ri/exx_abfs-parallel-communicate-dm3.h b/source/src_ri/exx_abfs-parallel-communicate-dm3.h index 43e30dd9f58..98851229b15 100644 --- a/source/src_ri/exx_abfs-parallel-communicate-dm3.h +++ b/source/src_ri/exx_abfs-parallel-communicate-dm3.h @@ -4,7 +4,7 @@ #include "exx_abfs-parallel.h" #include "abfs-vector3_order.h" #include "../module_base/complexmatrix.h" -#include "src_parallel/parallel_orbitals.h" +#include "module_orbital/parallel_orbitals.h" #include "src_lcao/local_orbital_charge.h" #ifdef __MPI #include "mpi.h" diff --git a/source/src_ri/exx_abfs-parallel-communicate-function.h b/source/src_ri/exx_abfs-parallel-communicate-function.h index 215cf095e2b..0db2427fc61 100644 --- a/source/src_ri/exx_abfs-parallel-communicate-function.h +++ b/source/src_ri/exx_abfs-parallel-communicate-function.h @@ -3,7 +3,7 @@ #define EXX_ABFS_PARALLEL_COMMUNICATE_FUNCTION_H #include "exx_abfs-parallel.h" -#include "src_parallel/parallel_orbitals.h" +#include "module_orbital/parallel_orbitals.h" #include #include #ifdef __MPI diff --git a/source/src_ri/exx_abfs-parallel-communicate-hexx.h b/source/src_ri/exx_abfs-parallel-communicate-hexx.h index f33bb4e8e45..aca55569d4b 100644 --- a/source/src_ri/exx_abfs-parallel-communicate-hexx.h +++ b/source/src_ri/exx_abfs-parallel-communicate-hexx.h @@ -13,7 +13,7 @@ #include #include "mpi.h" #include -#include "src_parallel/parallel_orbitals.h" +#include "module_orbital/parallel_orbitals.h" // mohan comment out 2021-02-06 //#include diff --git a/tests/integrate/260_NO_15_PK_PU_AF/INPUT b/tests/integrate/260_NO_15_PK_PU_AF/INPUT index e18597a2876..3022e1417a5 100644 --- a/tests/integrate/260_NO_15_PK_PU_AF/INPUT +++ b/tests/integrate/260_NO_15_PK_PU_AF/INPUT @@ -31,7 +31,6 @@ basis_type lcao gamma_only 1 symmetry 0 nspin 2 -newdm 1 #Parameter DFT+U dft_plus_u 1 diff --git a/tests/module_orb/src/ORB_unittest.cpp b/tests/module_orb/src/ORB_unittest.cpp deleted file mode 100644 index e5e008b8bf2..00000000000 --- a/tests/module_orb/src/ORB_unittest.cpp +++ /dev/null @@ -1,131 +0,0 @@ -#include "ORB_unittest.h" -#include -#include -#include - -test_orb::test_orb() -{} - -test_orb::~test_orb() -{} - -void test_orb::set_ekcut() -{ - std::cout << "set lcao_ecut from LCAO files" << std::endl; - //set as max of ekcut from every element - - lcao_ecut=0.0; - std::ifstream in_ao; - - for(int it=0;it> word; - if(word == "Cutoff(Ry)") break; - } - in_ao >> ek_current; - lcao_ecut = std::max(lcao_ecut,ek_current); - - in_ao.close(); - } - - ORB.ecutwfc=lcao_ecut; - cout << "lcao_ecut : " << lcao_ecut << std::endl; - - return; -} - -void test_orb::set_orbs(const double &lat0_in) -{ - for(int it=0;it> ofile; - ORB.orbital_file.push_back(ofile); - - std::cout << "Numerical orbital file : " << ofile << std::endl; - } - - return; -} - -void test_orb::count_ntype() -{ - std::cout << "count number of atom types" << std::endl; - std::ifstream ifs("STRU",std::ios::in); - - if (!ifs) - { - std::cout << "ERROR : file STRU does not exist" < -#include -#include -using namespace std; - -class test_orb -{ -public: - - test_orb(); - ~test_orb(); - - LCAO_Orbitals ORB; - ORB_gen_tables OGT; - ORB_control ooo; - - std::ofstream ofs_running; - - void count_ntype(); //from STRU, count types of elements - void set_files(); //from STRU, read names of LCAO files - void set_ekcut(); //from LCAO files, read and set ekcut - void set_orbs(const double &lat0_in); //interface to Read_PAO - - bool force_flag = 0; - int my_rank = 0; - int ntype; - - double lcao_ecut = 0; // (Ry) - double lcao_dk = 0.01; - double lcao_dr = 0.01; - double lcao_rmax = 30; // (a.u.) - - int out_descriptor = 0; - int out_r_matrix = 0; - - int lmax=1; - double lat0 = 1.0; -}; diff --git a/tests/module_orb/src/main.cpp b/tests/module_orb/src/main.cpp deleted file mode 100644 index e05342eb71c..00000000000 --- a/tests/module_orb/src/main.cpp +++ /dev/null @@ -1,37 +0,0 @@ -//#include "timer.h" -#include -#include -#include "ORB_unittest.h" -#include "../../../source/module_base/global_variable.h" -#include "../../../source/module_base/global_file.h" - -void calculate(); - -int main(int argc, char **argv) -{ - - std::cout << "Hello, this is the ORB module of ABACUS." << std::endl; - - calculate(); - - return 0; -} - -void calculate() -{ - GlobalV::BASIS_TYPE = "lcao"; - - test_orb test; - - test.ofs_running.open("log.txt"); - test.count_ntype(); - test.set_files(); - test.set_ekcut(); - test.set_orbs(test.lat0); - - std::cout << "--------------------" << std::endl; - std::cout << " Have a great day! " << std::endl; - std::cout << "--------------------" << std::endl; - - return; -}