Skip to content

Commit

Permalink
Refactor: add UnitTests and new functions and OverlapNew operator wit…
Browse files Browse the repository at this point in the history
…h HContainer (#2777)

* update demo1_SR

* Refactor: update demos in module_hcontainer

* Test: add unit test for parallel_orbitals.cpp

* update parallel_orbitals.h with get_indexes functions

* Refactor: small modifies for constructor of OperatorLCAO

* Refactor: update get_pointer() interface in AtomPair

* Refactor: add func_folding and UnitTest

* Feature: add new Operator for overlap of NAOs - OperatorNew

* Fix: bug in CI
  • Loading branch information
dyzheng committed Aug 2, 2023
1 parent ae4e7ea commit add4812
Show file tree
Hide file tree
Showing 37 changed files with 1,246 additions and 238 deletions.
39 changes: 7 additions & 32 deletions source/module_basis/module_ao/ORB_gen_tables.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ ORB_gen_tables UOT;
ORB_gen_tables::ORB_gen_tables() {}
ORB_gen_tables::~ORB_gen_tables() {}

const ORB_gen_tables& ORB_gen_tables::get_const_instance()
{
return GlobalC::UOT;
}

/// call in hamilt_linear::init_before_ions.
void ORB_gen_tables::gen_tables(
std::ofstream &ofs_in,
Expand Down Expand Up @@ -448,8 +453,6 @@ void ORB_gen_tables::snap_psipsi(
const int &L2,
const int &m2,
const int &N2,
const int &nspin,
std::complex<double> *olm1,
bool cal_syns,
double dmax) const
{
Expand Down Expand Up @@ -614,21 +617,7 @@ void ORB_gen_tables::snap_psipsi(
{
case 0: // calculate overlap.
{
if (nspin != 4)
{
olm[0] += tmpOlm0 * rly[MGT.get_lm_index(L, m)];
}
else if (olm1 != NULL)
{
olm1[0] += tmpOlm0 * rly[MGT.get_lm_index(L, m)];
olm1[1] += 0; //tmpOlm0 * (tmp(0,0)+tmp(0,1));
olm1[2] += 0; //tmpOlm0 * (tmp(1,0)+tmp(1,1));
olm1[3] += tmpOlm0 * rly[MGT.get_lm_index(L, m)];
}
else
{
ModuleBase::WARNING_QUIT("ORB_gen_tables::snap_psipsi", "something wrong!");
}
olm[0] += tmpOlm0 * rly[MGT.get_lm_index(L, m)];

/*
if( abs ( tmpOlm0 * rly[ MGT.get_lm_index(L, m) ] ) > 1.0e-3 )
Expand Down Expand Up @@ -710,21 +699,7 @@ void ORB_gen_tables::snap_psipsi(
{
case 0:
{
if (nspin != 4)
{
olm[0] += tmpKem0 * rly[MGT.get_lm_index(L, m)];
}
else if (olm1 != NULL)
{
olm1[0] += tmpKem0 * rly[MGT.get_lm_index(L, m)];
olm1[1] += 0; //tmpKem0 * (tmp(0,0)+tmp(0,1));
olm1[2] += 0; //tmpKem0 * (tmp(1,0)+tmp(1,1));
olm1[3] += tmpKem0 * rly[MGT.get_lm_index(L, m)];
}
else
{
ModuleBase::WARNING_QUIT("ORB_gen_tables::snap_psipsi", "something wrong in T.");
}
olm[0] += tmpKem0 * rly[MGT.get_lm_index(L, m)];
break;
}
case 1:
Expand Down
5 changes: 3 additions & 2 deletions source/module_basis/module_ao/ORB_gen_tables.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ class ORB_gen_tables
ORB_gen_tables();
~ORB_gen_tables();

// static function to get global instance
static const ORB_gen_tables& get_const_instance();

void gen_tables(
std::ofstream &ofs_in, // mohan add 2021-05-07
LCAO_Orbitals &orb,
Expand All @@ -49,8 +52,6 @@ class ORB_gen_tables
const int &l2,
const int &m2,
const int &n2,
const int &nspin,
std::complex<double> *olm1=NULL,
bool cal_syns = false,
double dmax = 0.0)const;

Expand Down
5 changes: 5 additions & 0 deletions source/module_basis/module_ao/ORB_read.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ LCAO_Orbitals::~LCAO_Orbitals()
delete[] Alpha;
}

const LCAO_Orbitals& LCAO_Orbitals::get_const_instance()
{
return GlobalC::ORB;
}

#ifdef __MPI
// be called in UnitCell.
void LCAO_Orbitals::bcast_files(
Expand Down
3 changes: 3 additions & 0 deletions source/module_basis/module_ao/ORB_read.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ class LCAO_Orbitals
LCAO_Orbitals();
~LCAO_Orbitals();

// static function to get global instance
static const LCAO_Orbitals& get_const_instance();

void Read_Orbitals(
std::ofstream &ofs_in, // mohan add 2021-05-07
const int &ntype_in,
Expand Down
100 changes: 88 additions & 12 deletions source/module_basis/module_ao/parallel_orbitals.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,36 +29,41 @@ Parallel_Orbitals::~Parallel_Orbitals()

void Parallel_Orbitals::set_atomic_trace(const int* iat2iwt, const int &nat, const int &nlocal)
{
this->atom_begin_col.resize(nat);
this->atom_begin_row.resize(nat);
for(int iat=0;iat<nat-1;iat++)
ModuleBase::TITLE("Parallel_Orbitals", "set_atomic_trace");
this->iat2iwt_ = iat2iwt;
int nat_plus_1 = nat + 1;
this->atom_begin_col.resize(nat_plus_1);
this->atom_begin_row.resize(nat_plus_1);
for(int iat=0;iat<nat;iat++)
{
this->atom_begin_col[iat] = -1;
this->atom_begin_row[iat] = -1;
int irow = iat2iwt[iat];
int icol = iat2iwt[iat];
const int max = (iat == nat-1) ? (nlocal - irow): (iat2iwt[iat+1] - irow);
//find the first row index of atom iat
for(int i=0;i<max;i++)
const int nw_global = (iat == nat-1) ? (nlocal - irow): (iat2iwt[iat+1] - irow);
//find the first local row index of atom iat
for(int i=0;i<nw_global;i++)
{
if (this->global2local_row_[irow] != -1)
{
this->atom_begin_row[iat] = irow;
this->atom_begin_row[iat] = this->global2local_row_[irow];
break;
}
irow++;
}
//find the first col index of atom iat
for(int i=0;i<max;i++)
//find the first local col index of atom iat
for(int i=0;i<nw_global;i++)
{
if (this->global2local_col_[icol] != -1)
{
this->atom_begin_col[iat] = icol;
this->atom_begin_col[iat] = this->global2local_col_[icol];
break;
}
icol++;
}
}
this->atom_begin_row[nat] = this->nrow;
this->atom_begin_col[nat] = this->ncol;
}

// Get the number of columns of the parallel orbital matrix
Expand Down Expand Up @@ -103,7 +108,7 @@ int Parallel_Orbitals::get_row_size(int iat) const
return 0;
}
iat += 1;
while(this->atom_begin_row[iat] <= this->ncol)
while(this->atom_begin_row[iat] <= this->nrow)
{
if(this->atom_begin_row[iat] != -1)
{
Expand All @@ -116,6 +121,77 @@ int Parallel_Orbitals::get_row_size(int iat) const
throw std::string("error in get_col_size(iat)");
}

// Get the global indexes of the rows of the parallel orbital matrix
std::vector<int> Parallel_Orbitals::get_indexes_row() const
{
std::vector<int> indexes(this->nrow);
for(int i = 0; i < this->nrow; i++)
{
#ifdef __MPI
indexes[i] = this->local2global_row(i);
#else
indexes[i] = i;
#endif
}
return indexes;
}
// Get the global indexes of the columns of the parallel orbital matrix
std::vector<int> Parallel_Orbitals::get_indexes_col() const
{
std::vector<int> indexes(this->ncol);
for(int i = 0; i < this->ncol; i++)
{
#ifdef __MPI
indexes[i] = this->local2global_col(i);
#else
indexes[i] = i;
#endif
}
return indexes;
}
// Get the global indexes of the rows of the orbital matrix of the iat-th atom
std::vector<int> Parallel_Orbitals::get_indexes_row(int iat) const
{
int size = this->get_row_size(iat);
if(size == 0)
{
return std::vector<int>();
}
std::vector<int> indexes(size);
int irow = this->atom_begin_row[iat];
int begin = this->iat2iwt_[iat];
for(int i = 0; i < size; ++i)
{
#ifdef __MPI
indexes[i] = this->local2global_row(irow + i) - begin;
#else
indexes[i] = i;
#endif
}
return indexes;
}
// Get the global indexes of the columns of the orbital matrix of the iat-th atom
std::vector<int> Parallel_Orbitals::get_indexes_col(int iat) const
{
int size = this->get_col_size(iat);
if(size == 0)
{
return std::vector<int>();
}
std::vector<int> indexes(size);
int icol = this->atom_begin_col[iat];
int begin = this->iat2iwt_[iat];
for(int i = 0; i < size; ++i)
{
#ifdef __MPI
indexes[i] = this->local2global_col(icol + i) - begin;
#else
indexes[i] = i;
#endif
}
return indexes;
}

#ifdef __MPI
void Parallel_Orbitals::set_desc_wfc_Eij(const int& nbasis, const int& nbands, const int& lld)
{
Expand Down Expand Up @@ -184,4 +260,4 @@ int Parallel_Orbitals::set_nloc_wfc_Eij(

return 0;
}
#endif
#endif
24 changes: 23 additions & 1 deletion source/module_basis/module_ao/parallel_orbitals.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,15 @@ class Parallel_Orbitals : public Parallel_2D
int* loc_sizes;
int loc_size;

// set row and col begin index for each atom
/**
* @brief set row and col begin index for each atom
* it should be called after:
* 1. nrow and ncol are set;
* 2. global2local_row_ and global2local_col_ are set;
* @param iat2iwt : the map from atom index to global oribtal indexes
* @param nat : number of atoms
* @param nlocal : number of global orbitals
*/
void set_atomic_trace(const int* iat2iwt, const int &nat, const int &nlocal);

/**
Expand All @@ -67,10 +75,24 @@ class Parallel_Orbitals : public Parallel_2D
int get_col_size(int iat) const;
int get_row_size(int iat) const;

/**
* @brief gather global indexes of orbitals in this processor
* get_indexes_row() : global indexes (~NLOCAL) of rows of Hamiltonian matrix in this processor
* get_indexes_col() : global indexes (~NLOCAL) of columns of Hamiltonian matrix in this processor
* get_indexes_row(iat) : global indexes (~nw) of rows of Hamiltonian matrix in atom iat
* get_indexes_col(iat) : global indexes (~nw) of columns of Hamiltonian matrix in atom iat
*/
std::vector<int> get_indexes_row() const;
std::vector<int> get_indexes_col() const;
std::vector<int> get_indexes_row(int iat) const;
std::vector<int> get_indexes_col(int iat) const;

// private:
// orbital index for each atom
std::vector<int> atom_begin_row;
std::vector<int> atom_begin_col;

const int* iat2iwt_ = nullptr;

};
#endif
8 changes: 4 additions & 4 deletions source/module_basis/module_ao/test/1_snap_equal_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,13 @@ TEST_F(test_orb, equal_test)
OGT.snap_psipsi(
ORB, olm_0, 0, 'S',
R1, T1, L1, m1, N1,
R2, T2, L2, m2, N2,
1, NULL);
R2, T2, L2, m2, N2
);
OGT.snap_psipsi(
ORB, olm_1, 1, 'S',
R1, T1, L1, m1, N1,
R2, T2, L2, m2, N2,
1, NULL);
R2, T2, L2, m2, N2
);
//std::cout << this->mock_center2_orb11[T1][T2][L1][N1][L2][N2]->cal_overlap(R1, R2, m1, m2);
clm_0 =
test_center2_orb11[T1][T2][L1][N1][L2][N2]->cal_overlap(R1, R2, m1, m2);
Expand Down
6 changes: 6 additions & 0 deletions source/module_basis/module_ao/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,12 @@ AddTest(
LIBS ${math_libs} device base
)

AddTest(
TARGET parallel_orbitals_test
SOURCES parallel_orbitals_test.cpp ../parallel_2d.cpp ../parallel_orbitals.cpp
LIBS ${math_libs} device base
)

install(DIRECTORY lcao_H2O DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
install(DIRECTORY lcao_H2O DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/../../../tests)

Loading

0 comments on commit add4812

Please sign in to comment.