Skip to content

Commit

Permalink
Refactor: Removed GlobalVs related to function ModuleIO::write_dm
Browse files Browse the repository at this point in the history
  • Loading branch information
AsTonyshment committed Mar 25, 2024
1 parent e2bdb53 commit 7089f23
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 21 deletions.
5 changes: 4 additions & 1 deletion source/module_io/dm_io.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@ void write_dm(
int out_dm,
double*** DM,
const double& ef,
const UnitCell* ucell);
const UnitCell* ucell,
const int my_rank,
const int nspin,
const int nlocal);

}

Expand Down
5 changes: 4 additions & 1 deletion source/module_io/output_dm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ void Output_DM::write()
_out_dm,
_DM,
_ef,
_ucell);
_ucell,
GlobalV::MY_RANK,
GlobalV::NSPIN,
GlobalV::NLOCAL);
}
} // namespace ModuleIO
2 changes: 1 addition & 1 deletion source/module_io/test_serial/dm_io_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ TEST_F(DMIOTest,Write)
std::string ssd = "SPIN1_DM";
int precision = 3;
int out_dm = 1;
ModuleIO::write_dm(is,0,ssd,precision,out_dm,DM,ef,ucell);
ModuleIO::write_dm(is,0,ssd,precision,out_dm,DM,ef,ucell,GlobalV::MY_RANK,GlobalV::NSPIN,GlobalV::NLOCAL);
std::ifstream ifs;
ifs.open("SPIN1_DM");
std::string str((std::istreambuf_iterator<char>(ifs)),std::istreambuf_iterator<char>());
Expand Down
39 changes: 21 additions & 18 deletions source/module_io/write_dm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,10 @@ void ModuleIO::write_dm(
int out_dm,
double*** DM,
const double& ef,
const UnitCell* ucell)
const UnitCell* ucell,
const int my_rank,
const int nspin,
const int nlocal)
{
ModuleBase::TITLE("ModuleIO","write_dm");

Expand All @@ -60,7 +63,7 @@ void ModuleIO::write_dm(
time_t start, end;
std::ofstream ofs;

if(GlobalV::MY_RANK==0)
if(my_rank==0)
{
start = time(NULL);

Expand Down Expand Up @@ -101,10 +104,10 @@ void ModuleIO::write_dm(
}
}

ofs << "\n " << GlobalV::NSPIN;
ofs << "\n " << nspin;
ofs << "\n " << ef << " (fermi energy)";

ofs << "\n " << GlobalV::NLOCAL << " " << GlobalV::NLOCAL << std::endl;
ofs << "\n " << nlocal << " " << nlocal << std::endl;

ofs << std::setprecision(precision);
ofs << std::scientific;
Expand All @@ -114,9 +117,9 @@ void ModuleIO::write_dm(
//ofs << "\n " << GlobalV::GAMMA_ONLY_LOCAL << " (GAMMA ONLY LOCAL)" << std::endl;
#ifndef __MPI

for(int i=0; i<GlobalV::NLOCAL; ++i)
for(int i=0; i<nlocal; ++i)
{
for(int j=0; j<GlobalV::NLOCAL; ++j)
for(int j=0; j<nlocal; ++j)
{
if(j%8==0) ofs << "\n";
ofs << " " << DM[is][i][j];
Expand All @@ -126,16 +129,16 @@ void ModuleIO::write_dm(
#else
//xiaohui modify 2014-06-18

double* tmp = new double[GlobalV::NLOCAL];
int* count = new int[GlobalV::NLOCAL];
for (int i=0; i<GlobalV::NLOCAL; ++i)
double* tmp = new double[nlocal];
int* count = new int[nlocal];
for (int i=0; i<nlocal; ++i)
{
// when reduce, there may be 'redundance', we need to count them.
ModuleBase::GlobalFunc::ZEROS(count, GlobalV::NLOCAL);
ModuleBase::GlobalFunc::ZEROS(count, nlocal);
const int mu = trace_lo[i];
if (mu >= 0)
{
for (int j=0; j<GlobalV::NLOCAL; ++j)
for (int j=0; j<nlocal; ++j)
{
const int nu = trace_lo[j];
if (nu >= 0)
Expand All @@ -144,13 +147,13 @@ void ModuleIO::write_dm(
}
}
}
Parallel_Reduce::reduce_all(count, GlobalV::NLOCAL);
Parallel_Reduce::reduce_all(count, nlocal);

// reduce the density matrix for 'i' line.
ModuleBase::GlobalFunc::ZEROS(tmp, GlobalV::NLOCAL);
ModuleBase::GlobalFunc::ZEROS(tmp, nlocal);
if (mu >= 0)
{
for (int j=0; j<GlobalV::NLOCAL; j++)
for (int j=0; j<nlocal; j++)
{
const int nu = trace_lo[j];
if (nu >=0)
Expand All @@ -160,11 +163,11 @@ void ModuleIO::write_dm(
}
}
}
Parallel_Reduce::reduce_all(tmp, GlobalV::NLOCAL);
Parallel_Reduce::reduce_all(tmp, nlocal);

if(GlobalV::MY_RANK==0)
if(my_rank==0)
{
for (int j=0; j<GlobalV::NLOCAL; j++)
for (int j=0; j<nlocal; j++)
{
if(j%8==0) ofs << "\n";
if(count[j]>0)
Expand All @@ -181,7 +184,7 @@ void ModuleIO::write_dm(
delete[] tmp;
delete[] count;
#endif
if(GlobalV::MY_RANK==0)
if(my_rank==0)
{
end = time(NULL);
ModuleBase::GlobalFunc::OUT_TIME("write_dm",start,end);
Expand Down

0 comments on commit 7089f23

Please sign in to comment.