Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: add UnitTests and new functions and OverlapNew operator with HContainer #2777

Merged
merged 10 commits into from
Aug 2, 2023
39 changes: 7 additions & 32 deletions source/module_basis/module_ao/ORB_gen_tables.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ ORB_gen_tables UOT;
ORB_gen_tables::ORB_gen_tables() {}
ORB_gen_tables::~ORB_gen_tables() {}

const ORB_gen_tables& ORB_gen_tables::get_const_instance()
{
return GlobalC::UOT;
}

/// call in hamilt_linear::init_before_ions.
void ORB_gen_tables::gen_tables(
std::ofstream &ofs_in,
Expand Down Expand Up @@ -448,8 +453,6 @@ void ORB_gen_tables::snap_psipsi(
const int &L2,
const int &m2,
const int &N2,
const int &nspin,
std::complex<double> *olm1,
bool cal_syns,
double dmax) const
{
Expand Down Expand Up @@ -614,21 +617,7 @@ void ORB_gen_tables::snap_psipsi(
{
case 0: // calculate overlap.
{
if (nspin != 4)
{
olm[0] += tmpOlm0 * rly[MGT.get_lm_index(L, m)];
}
else if (olm1 != NULL)
{
olm1[0] += tmpOlm0 * rly[MGT.get_lm_index(L, m)];
olm1[1] += 0; //tmpOlm0 * (tmp(0,0)+tmp(0,1));
olm1[2] += 0; //tmpOlm0 * (tmp(1,0)+tmp(1,1));
olm1[3] += tmpOlm0 * rly[MGT.get_lm_index(L, m)];
}
else
{
ModuleBase::WARNING_QUIT("ORB_gen_tables::snap_psipsi", "something wrong!");
}
olm[0] += tmpOlm0 * rly[MGT.get_lm_index(L, m)];

/*
if( abs ( tmpOlm0 * rly[ MGT.get_lm_index(L, m) ] ) > 1.0e-3 )
Expand Down Expand Up @@ -710,21 +699,7 @@ void ORB_gen_tables::snap_psipsi(
{
case 0:
{
if (nspin != 4)
{
olm[0] += tmpKem0 * rly[MGT.get_lm_index(L, m)];
}
else if (olm1 != NULL)
{
olm1[0] += tmpKem0 * rly[MGT.get_lm_index(L, m)];
olm1[1] += 0; //tmpKem0 * (tmp(0,0)+tmp(0,1));
olm1[2] += 0; //tmpKem0 * (tmp(1,0)+tmp(1,1));
olm1[3] += tmpKem0 * rly[MGT.get_lm_index(L, m)];
}
else
{
ModuleBase::WARNING_QUIT("ORB_gen_tables::snap_psipsi", "something wrong in T.");
}
olm[0] += tmpKem0 * rly[MGT.get_lm_index(L, m)];
break;
}
case 1:
Expand Down
5 changes: 3 additions & 2 deletions source/module_basis/module_ao/ORB_gen_tables.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ class ORB_gen_tables
ORB_gen_tables();
~ORB_gen_tables();

// static function to get global instance
static const ORB_gen_tables& get_const_instance();

void gen_tables(
std::ofstream &ofs_in, // mohan add 2021-05-07
LCAO_Orbitals &orb,
Expand All @@ -49,8 +52,6 @@ class ORB_gen_tables
const int &l2,
const int &m2,
const int &n2,
const int &nspin,
std::complex<double> *olm1=NULL,
bool cal_syns = false,
double dmax = 0.0)const;

Expand Down
5 changes: 5 additions & 0 deletions source/module_basis/module_ao/ORB_read.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ LCAO_Orbitals::~LCAO_Orbitals()
delete[] Alpha;
}

const LCAO_Orbitals& LCAO_Orbitals::get_const_instance()
{
return GlobalC::ORB;
}

#ifdef __MPI
// be called in UnitCell.
void LCAO_Orbitals::bcast_files(
Expand Down
3 changes: 3 additions & 0 deletions source/module_basis/module_ao/ORB_read.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ class LCAO_Orbitals
LCAO_Orbitals();
~LCAO_Orbitals();

// static function to get global instance
static const LCAO_Orbitals& get_const_instance();

void Read_Orbitals(
std::ofstream &ofs_in, // mohan add 2021-05-07
const int &ntype_in,
Expand Down
100 changes: 88 additions & 12 deletions source/module_basis/module_ao/parallel_orbitals.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,36 +29,41 @@ Parallel_Orbitals::~Parallel_Orbitals()

void Parallel_Orbitals::set_atomic_trace(const int* iat2iwt, const int &nat, const int &nlocal)
{
this->atom_begin_col.resize(nat);
this->atom_begin_row.resize(nat);
for(int iat=0;iat<nat-1;iat++)
ModuleBase::TITLE("Parallel_Orbitals", "set_atomic_trace");
this->iat2iwt_ = iat2iwt;
int nat_plus_1 = nat + 1;
this->atom_begin_col.resize(nat_plus_1);
this->atom_begin_row.resize(nat_plus_1);
for(int iat=0;iat<nat;iat++)
{
this->atom_begin_col[iat] = -1;
this->atom_begin_row[iat] = -1;
int irow = iat2iwt[iat];
int icol = iat2iwt[iat];
const int max = (iat == nat-1) ? (nlocal - irow): (iat2iwt[iat+1] - irow);
//find the first row index of atom iat
for(int i=0;i<max;i++)
const int nw_global = (iat == nat-1) ? (nlocal - irow): (iat2iwt[iat+1] - irow);
//find the first local row index of atom iat
for(int i=0;i<nw_global;i++)
{
if (this->global2local_row_[irow] != -1)
{
this->atom_begin_row[iat] = irow;
this->atom_begin_row[iat] = this->global2local_row_[irow];
break;
}
irow++;
}
//find the first col index of atom iat
for(int i=0;i<max;i++)
//find the first local col index of atom iat
for(int i=0;i<nw_global;i++)
{
if (this->global2local_col_[icol] != -1)
{
this->atom_begin_col[iat] = icol;
this->atom_begin_col[iat] = this->global2local_col_[icol];
break;
}
icol++;
}
}
this->atom_begin_row[nat] = this->nrow;
this->atom_begin_col[nat] = this->ncol;
}

// Get the number of columns of the parallel orbital matrix
Expand Down Expand Up @@ -103,7 +108,7 @@ int Parallel_Orbitals::get_row_size(int iat) const
return 0;
}
iat += 1;
hongriTianqi marked this conversation as resolved.
Show resolved Hide resolved
while(this->atom_begin_row[iat] <= this->ncol)
while(this->atom_begin_row[iat] <= this->nrow)
{
if(this->atom_begin_row[iat] != -1)
{
Expand All @@ -116,6 +121,77 @@ int Parallel_Orbitals::get_row_size(int iat) const
throw std::string("error in get_col_size(iat)");
}

// Get the global indexes of the rows of the parallel orbital matrix
std::vector<int> Parallel_Orbitals::get_indexes_row() const
{
std::vector<int> indexes(this->nrow);
for(int i = 0; i < this->nrow; i++)
{
#ifdef __MPI
indexes[i] = this->local2global_row(i);
#else
indexes[i] = i;
#endif
}
return indexes;
}
// Get the global indexes of the columns of the parallel orbital matrix
std::vector<int> Parallel_Orbitals::get_indexes_col() const
{
std::vector<int> indexes(this->ncol);
for(int i = 0; i < this->ncol; i++)
{
#ifdef __MPI
indexes[i] = this->local2global_col(i);
#else
indexes[i] = i;
#endif
}
return indexes;
}
// Get the global indexes of the rows of the orbital matrix of the iat-th atom
std::vector<int> Parallel_Orbitals::get_indexes_row(int iat) const
{
int size = this->get_row_size(iat);
if(size == 0)
{
return std::vector<int>();
}
std::vector<int> indexes(size);
int irow = this->atom_begin_row[iat];
int begin = this->iat2iwt_[iat];
for(int i = 0; i < size; ++i)
{
#ifdef __MPI
indexes[i] = this->local2global_row(irow + i) - begin;
#else
indexes[i] = i;
#endif
}
return indexes;
}
// Get the global indexes of the columns of the orbital matrix of the iat-th atom
std::vector<int> Parallel_Orbitals::get_indexes_col(int iat) const
{
int size = this->get_col_size(iat);
if(size == 0)
{
return std::vector<int>();
}
std::vector<int> indexes(size);
int icol = this->atom_begin_col[iat];
int begin = this->iat2iwt_[iat];
for(int i = 0; i < size; ++i)
{
#ifdef __MPI
indexes[i] = this->local2global_col(icol + i) - begin;
#else
indexes[i] = i;
#endif
}
return indexes;
}

#ifdef __MPI
void Parallel_Orbitals::set_desc_wfc_Eij(const int& nbasis, const int& nbands, const int& lld)
{
Expand Down Expand Up @@ -184,4 +260,4 @@ int Parallel_Orbitals::set_nloc_wfc_Eij(

return 0;
}
#endif
#endif
24 changes: 23 additions & 1 deletion source/module_basis/module_ao/parallel_orbitals.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,15 @@ class Parallel_Orbitals : public Parallel_2D
int* loc_sizes;
int loc_size;

// set row and col begin index for each atom
/**
* @brief set row and col begin index for each atom
* it should be called after:
* 1. nrow and ncol are set;
* 2. global2local_row_ and global2local_col_ are set;
* @param iat2iwt : the map from atom index to global oribtal indexes
* @param nat : number of atoms
* @param nlocal : number of global orbitals
*/
void set_atomic_trace(const int* iat2iwt, const int &nat, const int &nlocal);

/**
Expand All @@ -67,10 +75,24 @@ class Parallel_Orbitals : public Parallel_2D
int get_col_size(int iat) const;
int get_row_size(int iat) const;

/**
* @brief gather global indexes of orbitals in this processor
* get_indexes_row() : global indexes (~NLOCAL) of rows of Hamiltonian matrix in this processor
* get_indexes_col() : global indexes (~NLOCAL) of columns of Hamiltonian matrix in this processor
* get_indexes_row(iat) : global indexes (~nw) of rows of Hamiltonian matrix in atom iat
* get_indexes_col(iat) : global indexes (~nw) of columns of Hamiltonian matrix in atom iat
*/
std::vector<int> get_indexes_row() const;
std::vector<int> get_indexes_col() const;
std::vector<int> get_indexes_row(int iat) const;
std::vector<int> get_indexes_col(int iat) const;

// private:
// orbital index for each atom
std::vector<int> atom_begin_row;
std::vector<int> atom_begin_col;

const int* iat2iwt_ = nullptr;

};
#endif
8 changes: 4 additions & 4 deletions source/module_basis/module_ao/test/1_snap_equal_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,13 @@ TEST_F(test_orb, equal_test)
OGT.snap_psipsi(
ORB, olm_0, 0, 'S',
R1, T1, L1, m1, N1,
R2, T2, L2, m2, N2,
1, NULL);
R2, T2, L2, m2, N2
);
OGT.snap_psipsi(
ORB, olm_1, 1, 'S',
R1, T1, L1, m1, N1,
R2, T2, L2, m2, N2,
1, NULL);
R2, T2, L2, m2, N2
);
//std::cout << this->mock_center2_orb11[T1][T2][L1][N1][L2][N2]->cal_overlap(R1, R2, m1, m2);
clm_0 =
test_center2_orb11[T1][T2][L1][N1][L2][N2]->cal_overlap(R1, R2, m1, m2);
Expand Down
6 changes: 6 additions & 0 deletions source/module_basis/module_ao/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,12 @@ AddTest(
LIBS ${math_libs} device base
)

AddTest(
TARGET parallel_orbitals_test
SOURCES parallel_orbitals_test.cpp ../parallel_2d.cpp ../parallel_orbitals.cpp
LIBS ${math_libs} device base
)

install(DIRECTORY lcao_H2O DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
install(DIRECTORY lcao_H2O DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/../../../tests)

Loading