Skip to content

Commit

Permalink
Perf: optimize omp critical sections (#1492)
Browse files Browse the repository at this point in the history
  • Loading branch information
caic99 committed Nov 11, 2022
1 parent d4a1242 commit 9b79b4d
Showing 1 changed file with 14 additions and 15 deletions.
29 changes: 14 additions & 15 deletions source/module_gint/gint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ void Gint::cal_gint(Gint_inout *inout)
if(inout->job==Gint_Tools::job_type::tau) ModuleBase::TITLE("Gint_interface","cal_gint_tau");
if(inout->job==Gint_Tools::job_type::force) ModuleBase::TITLE("Gint_interface","cal_gint_force");
if(inout->job==Gint_Tools::job_type::force_meta) ModuleBase::TITLE("Gint_interface","cal_gint_force_meta");

if(inout->job==Gint_Tools::job_type::vlocal) ModuleBase::timer::tick("Gint_interface", "cal_gint_vlocal");
if(inout->job==Gint_Tools::job_type::vlocal_meta) ModuleBase::timer::tick("Gint_interface","cal_gint_vlocal_meta");
if(inout->job==Gint_Tools::job_type::rho) ModuleBase::timer::tick("Gint_interface","cal_gint_rho");
Expand All @@ -48,7 +48,7 @@ void Gint::cal_gint(Gint_inout *inout)
//prepare some constants
const int ncyz = GlobalC::rhopw->ny*GlobalC::rhopw->nplane; // mohan add 2012-03-25
const double dv = GlobalC::ucell.omega/this->ncxyz;

// it's a uniform grid to save orbital values, so the delta_r is a constant.
const double delta_r = GlobalC::ORB.dr_uniform;

Expand Down Expand Up @@ -105,7 +105,7 @@ void Gint::cal_gint(Gint_inout *inout)
// get the value: how many atoms has orbital value on this grid.
const int na_grid = GlobalC::GridT.how_many_atoms[ grid_index ];

if(na_grid==0) continue;
if(na_grid==0) continue;

if(inout->job == Gint_Tools::job_type::rho)
{
Expand All @@ -127,11 +127,11 @@ void Gint::cal_gint(Gint_inout *inout)
if(GlobalV::GAMMA_ONLY_LOCAL) DM_in = inout->DM[GlobalV::CURRENT_SPIN];
if(!GlobalV::GAMMA_ONLY_LOCAL) DM_in = inout->DM_R;
#ifdef _OPENMP
this->gint_kernel_force(na_grid, grid_index, delta_r, vldr3, LD_pool,
this->gint_kernel_force(na_grid, grid_index, delta_r, vldr3, LD_pool,
DM_in, inout->isforce, inout->isstress,
&fvl_dphi_thread, &svl_dphi_thread);
#else
this->gint_kernel_force(na_grid, grid_index, delta_r, vldr3, LD_pool,
this->gint_kernel_force(na_grid, grid_index, delta_r, vldr3, LD_pool,
DM_in, inout->isforce, inout->isstress,
inout->fvl_dphi, inout->svl_dphi);
#endif
Expand Down Expand Up @@ -191,52 +191,51 @@ void Gint::cal_gint(Gint_inout *inout)
if(GlobalV::GAMMA_ONLY_LOCAL) DM_in = inout->DM[GlobalV::CURRENT_SPIN];
if(!GlobalV::GAMMA_ONLY_LOCAL) DM_in = inout->DM_R;
#ifdef _OPENMP
this->gint_kernel_force_meta(na_grid, grid_index, delta_r, vldr3, vkdr3, LD_pool,
this->gint_kernel_force_meta(na_grid, grid_index, delta_r, vldr3, vkdr3, LD_pool,
DM_in, inout->isforce, inout->isstress,
&fvl_dphi_thread, &svl_dphi_thread);
#else
this->gint_kernel_force_meta(na_grid, grid_index, delta_r, vldr3, vkdr3, LD_pool,
this->gint_kernel_force_meta(na_grid, grid_index, delta_r, vldr3, vkdr3, LD_pool,
DM_in, inout->isforce, inout->isstress,
inout->fvl_dphi, inout->svl_dphi);
#endif
delete[] vldr3;
delete[] vkdr3;
}
}
} // int grid_index

#ifdef _OPENMP
if(inout->job==Gint_Tools::job_type::vlocal || inout->job==Gint_Tools::job_type::vlocal_meta)
{
if(GlobalV::GAMMA_ONLY_LOCAL && lgd>0)
{
#pragma omp critical(gint_gamma)
for(int i=0;i<lgd*lgd;i++)
{
#pragma omp critical(gint_gamma)
pvpR_grid[i] += pvpR_thread[i];
}
delete[] pvpR_thread;
}
if(!GlobalV::GAMMA_ONLY_LOCAL)
{
#pragma omp critical(gint_k)
for(int innrg=0; innrg<GlobalC::GridT.nnrg; innrg++)
{
#pragma omp critical(gint_k)
pvpR_reduced[inout->ispin][innrg] += pvpR_thread[innrg];
}
delete[] pvpR_thread;
}
}

#pragma omp critical(gint)
if(inout->job==Gint_Tools::job_type::force || inout->job==Gint_Tools::job_type::force_meta)
{
if(inout->isforce)
{
#pragma omp critical(gint)
inout->fvl_dphi[0]+=fvl_dphi_thread;
}
if(inout->isstress)
{
#pragma omp critical(gint)
inout->svl_dphi[0]+=svl_dphi_thread;
}
}
Expand All @@ -246,7 +245,7 @@ void Gint::cal_gint(Gint_inout *inout)
#ifdef __MKL
mkl_set_num_threads(mkl_threads);
#endif
} // end of if (max_size)
} // end of if (max_size)

ModuleBase::timer::tick("Gint_interface", "cal_gint");

Expand All @@ -267,7 +266,7 @@ void Gint::prep_grid(
const int& ncxyz_in)
{
ModuleBase::TITLE(GlobalV::ofs_running,"Gint_k","prep_grid");

this->nbx = nbx_in;
this->nby = nby_in;
this->nbz = nbz_in;
Expand All @@ -281,4 +280,4 @@ void Gint::prep_grid(
assert( GlobalC::ucell.omega > 0.0);

return;
}
}

0 comments on commit 9b79b4d

Please sign in to comment.