Skip to content

Commit

Permalink
Feature & Refactor: read and write Hexx(R) in CSR format (#3727)
Browse files Browse the repository at this point in the history
* refactor sparse output

* read HexxR in CSR

* parallel support

* remove test code

* comments

* recover multiple process

* refactor singleR

* change func names
  • Loading branch information
maki49 committed Apr 2, 2024
1 parent c78456c commit dc7938b
Show file tree
Hide file tree
Showing 12 changed files with 320 additions and 457 deletions.
6 changes: 3 additions & 3 deletions source/module_esolver/esolver_ks_lcao.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1039,14 +1039,14 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(const int istep)
#ifdef __EXX
if (GlobalC::exx_info.info_global.cal_exx) // Peize Lin add if 2022.11.14
{
const std::string file_name_exx = GlobalV::global_out_dir + "HexxR_" + std::to_string(GlobalV::MY_RANK);
const std::string file_name_exx = GlobalV::global_out_dir + "HexxR" + std::to_string(GlobalV::MY_RANK);
if (GlobalC::exx_info.info_ri.real_number)
{
this->exd->write_Hexxs(file_name_exx);
this->exd->write_Hexxs_csr(file_name_exx, GlobalC::ucell);
}
else
{
this->exc->write_Hexxs(file_name_exx);
this->exc->write_Hexxs_csr(file_name_exx, GlobalC::ucell);
}
}
#endif
Expand Down
6 changes: 3 additions & 3 deletions source/module_esolver/esolver_ks_lcao_elec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -529,11 +529,11 @@ void ESolver_KS_LCAO<TK, TR>::nscf()
if (GlobalC::exx_info.info_global.cal_exx)
{
// GlobalC::exx_lcao.cal_exx_elec_nscf(this->LOWF.ParaV[0]);
const std::string file_name_exx = GlobalV::global_out_dir + "HexxR_" + std::to_string(GlobalV::MY_RANK);
const std::string file_name_exx = GlobalV::global_out_dir + "HexxR" + std::to_string(GlobalV::MY_RANK);
if (GlobalC::exx_info.info_ri.real_number)
this->exd->read_Hexxs(file_name_exx);
this->exd->read_Hexxs_csr(file_name_exx, GlobalC::ucell);
else
this->exc->read_Hexxs(file_name_exx);
this->exc->read_Hexxs_csr(file_name_exx, GlobalC::ucell);

hamilt::HamiltLCAO<TK, TR>* hamilt_lcao = dynamic_cast<hamilt::HamiltLCAO<TK, TR>*>(this->p_hamilt);
auto exx = new hamilt::OperatorEXX<hamilt::OperatorLCAO<TK, TR>>(&this->LM,
Expand Down
2 changes: 1 addition & 1 deletion source/module_io/csr_reader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,6 @@ int csrFileReader<T>::getStep() const
// T of AtomPair can be double
template class csrFileReader<double>;
// ToDo: T of AtomPair can be std::complex<double>
// template class csrFileReader<std::complex<double>>;
template class csrFileReader<std::complex<double>>;

} // namespace ModuleIO
161 changes: 37 additions & 124 deletions source/module_io/single_R_io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,132 +3,35 @@
#include "module_base/global_function.h"
#include "module_base/global_variable.h"

void ModuleIO::output_single_R(std::ofstream &ofs, const std::map<size_t, std::map<size_t, double>> &XR, const double &sparse_threshold, const bool &binary, const Parallel_Orbitals &pv)
inline void write_data(std::ofstream& ofs, const double& data)
{
double *line = nullptr;
std::vector<int> indptr;
indptr.reserve(GlobalV::NLOCAL + 1);
indptr.push_back(0);

std::stringstream tem1;
tem1 << GlobalV::global_out_dir << "temp_sparse_indices.dat";
std::ofstream ofs_tem1;
std::ifstream ifs_tem1;

if (GlobalV::DRANK == 0)
{
if (binary)
{
ofs_tem1.open(tem1.str().c_str(), std::ios::binary);
}
else
{
ofs_tem1.open(tem1.str().c_str());
}
}

line = new double[GlobalV::NLOCAL];
for(int row = 0; row < GlobalV::NLOCAL; ++row)
{
// line = new double[GlobalV::NLOCAL];
ModuleBase::GlobalFunc::ZEROS(line, GlobalV::NLOCAL);

if(pv.global2local_row(row) >= 0)
{
auto iter = XR.find(row);
if (iter != XR.end())
{
for (auto &value : iter->second)
{
line[value.first] = value.second;
}
}
}

Parallel_Reduce::reduce_all(line, GlobalV::NLOCAL);

if(GlobalV::DRANK == 0)
{
int nonzeros_count = 0;
for (int col = 0; col < GlobalV::NLOCAL; ++col)
{
if (std::abs(line[col]) > sparse_threshold)
{
if (binary)
{
ofs.write(reinterpret_cast<char *>(&line[col]), sizeof(double));
ofs_tem1.write(reinterpret_cast<char *>(&col), sizeof(int));
}
else
{
ofs << " " << std::fixed << std::scientific << std::setprecision(8) << line[col];
ofs_tem1 << " " << col;
}

nonzeros_count++;

}

}
nonzeros_count += indptr.back();
indptr.push_back(nonzeros_count);
}

// delete[] line;
// line = nullptr;

}

delete[] line;
line = nullptr;

if (GlobalV::DRANK == 0)
{
if (binary)
{
ofs_tem1.close();
ifs_tem1.open(tem1.str().c_str(), std::ios::binary);
ofs << ifs_tem1.rdbuf();
ifs_tem1.close();
for (auto &i : indptr)
{
ofs.write(reinterpret_cast<char *>(&i), sizeof(int));
}
}
else
{
ofs << std::endl;
ofs_tem1 << std::endl;
ofs_tem1.close();
ifs_tem1.open(tem1.str().c_str());
ofs << ifs_tem1.rdbuf();
ifs_tem1.close();
for (auto &i : indptr)
{
ofs << " " << i;
}
ofs << std::endl;
}

std::remove(tem1.str().c_str());

}

ofs << " " << std::fixed << std::scientific << std::setprecision(8) << data;
}
inline void write_data(std::ofstream& ofs, const std::complex<double>& data)
{
ofs << " (" << std::fixed << std::scientific << std::setprecision(8) << data.real() << ","
<< std::fixed << std::scientific << std::setprecision(8) << data.imag() << ")";
}

void ModuleIO::output_soc_single_R(std::ofstream &ofs, const std::map<size_t, std::map<size_t, std::complex<double>>> &XR, const double &sparse_threshold, const bool &binary, const Parallel_Orbitals &pv)
template<typename T>
void ModuleIO::output_single_R(std::ofstream& ofs,
const std::map<size_t, std::map<size_t, T>>& XR,
const double& sparse_threshold,
const bool& binary,
const Parallel_Orbitals& pv,
const bool& reduce)
{
std::complex<double> *line = nullptr;
T* line = nullptr;
std::vector<int> indptr;
indptr.reserve(GlobalV::NLOCAL + 1);
indptr.push_back(0);

std::stringstream tem1;
tem1 << GlobalV::global_out_dir << "temp_sparse_indices.dat";
tem1 << GlobalV::global_out_dir << std::to_string(GlobalV::DRANK) + "temp_sparse_indices.dat";
std::ofstream ofs_tem1;
std::ifstream ifs_tem1;

if (GlobalV::DRANK == 0)
if (!reduce || GlobalV::DRANK == 0)
{
if (binary)
{
Expand All @@ -140,13 +43,12 @@ void ModuleIO::output_soc_single_R(std::ofstream &ofs, const std::map<size_t, st
}
}

line = new std::complex<double>[GlobalV::NLOCAL];
line = new T[GlobalV::NLOCAL];
for(int row = 0; row < GlobalV::NLOCAL; ++row)
{
// line = new std::complex<double>[GlobalV::NLOCAL];
ModuleBase::GlobalFunc::ZEROS(line, GlobalV::NLOCAL);

if(pv.global2local_row(row) >= 0)
if (!reduce || pv.global2local_row(row) >= 0)
{
auto iter = XR.find(row);
if (iter != XR.end())
Expand All @@ -158,9 +60,9 @@ void ModuleIO::output_soc_single_R(std::ofstream &ofs, const std::map<size_t, st
}
}

Parallel_Reduce::reduce_all(line, GlobalV::NLOCAL);
if (reduce)Parallel_Reduce::reduce_all(line, GlobalV::NLOCAL);

if (GlobalV::DRANK == 0)
if (!reduce || GlobalV::DRANK == 0)
{
int nonzeros_count = 0;
for (int col = 0; col < GlobalV::NLOCAL; ++col)
Expand All @@ -169,13 +71,12 @@ void ModuleIO::output_soc_single_R(std::ofstream &ofs, const std::map<size_t, st
{
if (binary)
{
ofs.write(reinterpret_cast<char *>(&line[col]), sizeof(std::complex<double>));
ofs.write(reinterpret_cast<char*>(&line[col]), sizeof(T));
ofs_tem1.write(reinterpret_cast<char *>(&col), sizeof(int));
}
else
{
ofs << " (" << std::fixed << std::scientific << std::setprecision(8) << line[col].real() << ","
<< std::fixed << std::scientific << std::setprecision(8) << line[col].imag() << ")";
write_data(ofs, line[col]);
ofs_tem1 << " " << col;
}

Expand All @@ -196,7 +97,7 @@ void ModuleIO::output_soc_single_R(std::ofstream &ofs, const std::map<size_t, st
delete[] line;
line = nullptr;

if (GlobalV::DRANK == 0)
if (!reduce || GlobalV::DRANK == 0)
{
if (binary)
{
Expand Down Expand Up @@ -226,5 +127,17 @@ void ModuleIO::output_soc_single_R(std::ofstream &ofs, const std::map<size_t, st

std::remove(tem1.str().c_str());
}

}

template void ModuleIO::output_single_R<double>(std::ofstream& ofs,
const std::map<size_t, std::map<size_t, double>>& XR,
const double& sparse_threshold,
const bool& binary,
const Parallel_Orbitals& pv,
const bool& reduce);
template void ModuleIO::output_single_R<std::complex<double>>(std::ofstream& ofs,
const std::map<size_t, std::map<size_t, std::complex<double>>>& XR,
const double& sparse_threshold,
const bool& binary,
const Parallel_Orbitals& pv,
const bool& reduce);
9 changes: 7 additions & 2 deletions source/module_io/single_R_io.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,13 @@

namespace ModuleIO
{
void output_single_R(std::ofstream &ofs, const std::map<size_t, std::map<size_t, double>> &XR, const double &sparse_threshold, const bool &binary, const Parallel_Orbitals &pv);
void output_soc_single_R(std::ofstream &ofs, const std::map<size_t, std::map<size_t, std::complex<double>>> &XR, const double &sparse_threshold, const bool &binary, const Parallel_Orbitals &pv);
template <typename T>
void output_single_R(std::ofstream& ofs,
const std::map<size_t, std::map<size_t, T>>& XR,
const double& sparse_threshold,
const bool& binary,
const Parallel_Orbitals& pv,
const bool& reduce = true);
}

#endif
2 changes: 1 addition & 1 deletion source/module_io/sparse_matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ void SparseMatrix<T>::readCSR(const std::vector<T>& values,

// define the operator to index a matrix element
template <typename T>
T SparseMatrix<T>::operator()(int row, int col)
T SparseMatrix<T>::operator()(int row, int col) const
{
if (row < 0 || row >= _rows || col < 0 || col >= _cols)
{
Expand Down
2 changes: 1 addition & 1 deletion source/module_io/sparse_matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class SparseMatrix
}

// define the operator to index a matrix element
T operator()(int row, int col);
T operator()(int row, int col)const;

// set the threshold
void setSparseThreshold(double sparse_threshold)
Expand Down
4 changes: 2 additions & 2 deletions source/module_io/write_HS_R.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ void ModuleIO::output_S_R(
ModuleBase::timer::tick("ModuleIO","output_S_R");

UHM.cal_SR_sparse(sparse_threshold, p_ham);
ModuleIO::save_SR_sparse(*UHM.LM, sparse_threshold, binary, SR_filename);
ModuleIO::save_sparse(UHM.LM->SR_sparse, UHM.LM->all_R_coor, sparse_threshold, binary, SR_filename, *UHM.LM->ParaV, "S", 0);
UHM.destroy_all_HSR_sparse();

ModuleBase::timer::tick("ModuleIO","output_S_R");
Expand Down Expand Up @@ -154,7 +154,7 @@ void ModuleIO::output_T_R(
}

UHM.cal_TR_sparse(sparse_threshold);
ModuleIO::save_TR_sparse(istep, *UHM.LM, sparse_threshold, binary, sst.str().c_str());
ModuleIO::save_sparse(UHM.LM->TR_sparse, UHM.LM->all_R_coor, sparse_threshold, binary, sst.str().c_str(), *UHM.LM->ParaV, "T", istep);
UHM.destroy_TR_sparse();

ModuleBase::timer::tick("ModuleIO","output_T_R");
Expand Down

0 comments on commit dc7938b

Please sign in to comment.