Refactor: use less memory to calculate stress in pw base #4047

Open
wants to merge 51 commits into
base: develop

Conversation

dyzheng
Collaborator

@dyzheng dyzheng commented Apr 23, 2024

I have refactored the stress code structure in this PR.
In the Mg16Al16 case, the memory cost of the stress calculation drops from 16752 MB to 194 MB.

Linked Issue

Close #3714
Close #4158
Close #3710
Close #4026
Close #3931
Close #4031

Unit Tests and/or Case Tests for my changes

  • A unit test is added for each new feature or bug fix.

What's changed?

  • Example: My changes might affect the performance of the application under certain conditions, and I have tested the impact on various scenarios...

Any changes of core modules? (ignore if not applicable)

  • Example: I have added a new virtual function in the esolver base class in order to ...

@Qianruipku
Collaborator

Has the efficiency of the new algorithm been tested? Are there any test data available?

@dyzheng
Collaborator Author

dyzheng commented Apr 24, 2024

> Has the efficiency of the new algorithm been tested? Are there any test data available?

I have not tested many cases. In the Mg16Al16 case, the time of stress_nl changed from 94 s to 124 s. I think the performance of the new method can still be improved.

@Qianruipku
Collaborator

Perhaps the QE code can be used as a reference.

if (gx[ig].norm2() > 1e-9) {
for(int ig = 0;ig< ngy;ig++)
{
FPTYPE norm2 = gx[ig * 3] * gx[ig * 3] + gx[ig * 3 + 1] * gx[ig * 3 + 1] + gx[ig * 3 + 2] * gx[ig * 3 + 2];
Collaborator

ig * 3 can be computed once and stored in a local variable first.
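A minimal sketch of this suggestion, assuming the G-vectors are stored as flat (x, y, z) triples as in the diff. The helper name `fill_dg` and its signature are hypothetical and not taken from the PR:

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper (not the PR's function): dg[ig] = delta * |g_ig|,
// where gx holds the vectors as flat (x, y, z) triples.
template <typename FPTYPE>
std::vector<FPTYPE> fill_dg(const std::vector<FPTYPE>& gx, FPTYPE delta)
{
    const std::size_t ngy = gx.size() / 3;
    std::vector<FPTYPE> dg(ngy);
    for (std::size_t ig = 0; ig < ngy; ++ig)
    {
        const std::size_t i3 = ig * 3; // computed once, reused for x, y, z
        const FPTYPE norm2 = gx[i3] * gx[i3] + gx[i3 + 1] * gx[i3 + 1] + gx[i3 + 2] * gx[i3 + 2];
        dg[ig] = delta * std::sqrt(norm2);
    }
    return dg;
}
```

Optimizing compilers usually eliminate the repeated `ig * 3` on their own, so the main gain is likely readability rather than speed.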

{
FPTYPE norm2 = gx[ig * 3] * gx[ig * 3] + gx[ig * 3 + 1] * gx[ig * 3 + 1] + gx[ig * 3 + 2] * gx[ig * 3 + 2];
dg [ig] = delta * sqrt(norm2) ;
if (norm2 > 1e-9) {
Collaborator

Why is this threshold set to 1e-9?

@@ -49,7 +49,8 @@ void Stress_PW<FPTYPE, Device>::stress_us(ModuleBase::matrix& sigma,
ModuleBase::matrix dylmk0(ppcell_in->lmaxq * ppcell_in->lmaxq, npw);
for (int ipol = 0; ipol < 3; ipol++)
{
this->dylmr2(ppcell_in->lmaxq * ppcell_in->lmaxq, npw, rho_basis->gcar, dylmk0, ipol);
double* gcar_ptr = reinterpret_cast<double*>(rho_basis->gcar);
Collaborator

What is the type of gcar? Is the cast necessary?
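To make the question concrete, here is a sketch of what such a cast relies on. The type of `gcar` is not shown in this thread, so `Vec3` below is a hypothetical stand-in, and `as_flat` and `component` are illustrative names:

```cpp
#include <cstddef>
#include <type_traits>

// Hypothetical stand-in for a 3-component vector type.
struct Vec3
{
    double x, y, z;
};

// The cast is only meaningful if the struct is exactly three packed doubles.
static_assert(std::is_standard_layout<Vec3>::value, "Vec3 must be standard layout");
static_assert(sizeof(Vec3) == 3 * sizeof(double), "Vec3 must have no padding");

// View an array of Vec3 as a flat x0, y0, z0, x1, ... array of doubles.
inline const double* as_flat(const Vec3* v)
{
    return reinterpret_cast<const double*>(v);
}

// Cast-free alternative: read component c (0, 1, 2) of element i by name.
inline double component(const Vec3* v, std::size_t i, int c)
{
    return c == 0 ? v[i].x : (c == 1 ? v[i].y : v[i].z);
}
```

The flat view works on common compilers when the assertions hold, but the language standard gives limited guarantees for indexing past the first member through the cast pointer, so the named-member access is the more conservative choice where it is fast enough.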

@@ -634,30 +815,32 @@ void Stress_Func<FPTYPE, Device>::dylmr2 (
// gx = g +/- dg


ModuleBase::Vector3<FPTYPE> *gx = new ModuleBase::Vector3<FPTYPE> [ngy];
FPTYPE *gx = new FPTYPE[3 * ngy];
Collaborator

can you use instead of new?
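The comment does not say which alternative is meant. One common way to avoid a raw `new[]`, shown here purely as an assumption about what could be used, is a `std::vector` that owns the buffer. The helper name `make_gx_buffer` is hypothetical:

```cpp
#include <cstddef>
#include <vector>

// Hypothetical sketch: a flat (x, y, z) buffer for ngy G-vectors whose
// storage is released automatically when the vector goes out of scope.
template <typename FPTYPE>
std::vector<FPTYPE> make_gx_buffer(std::size_t ngy)
{
    return std::vector<FPTYPE>(3 * ngy, FPTYPE(0)); // zero-initialized
}
```

Code that still expects a raw `FPTYPE*` can take `gx.data()`, and no matching `delete[]` is needed on any exit path.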

// loop over all G-vectors
for(int ig=0;ig<npw;ig++)
{
vkb_ptr[ig] -= 2.0 * ylm_ptr[ig] * vq_deri_ptr[ig] * sk_in[ig] * pref_in[ih]
Collaborator

Can ig * 3 be calculated first? The multiplication by 2.0 can be done once after this for loop.
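A sketch of hoisting the constant factor, assuming the loop accumulates several terms into the same output array. The function `accumulate_scaled` and its arguments are hypothetical and much simpler than the real vkb update:

```cpp
#include <cstddef>
#include <vector>

// Hypothetical accumulation: out[ig] -= 2 * sum over terms of a[t][ig] * b[t][ig].
// The factor 2 is applied once per element after the term loop instead of
// inside every term.
template <typename FPTYPE>
void accumulate_scaled(std::vector<FPTYPE>& out,
                       const std::vector<std::vector<FPTYPE>>& a,
                       const std::vector<std::vector<FPTYPE>>& b)
{
    std::vector<FPTYPE> sum(out.size(), FPTYPE(0));
    for (std::size_t t = 0; t < a.size(); ++t)
    {
        for (std::size_t ig = 0; ig < out.size(); ++ig)
        {
            sum[ig] += a[t][ig] * b[t][ig];
        }
    }
    for (std::size_t ig = 0; ig < out.size(); ++ig)
    {
        out[ig] -= FPTYPE(2) * sum[ig];
    }
}
```

The temporary `sum` keeps the result correct even when `out` already holds earlier contributions. If `out` is known to start at zero, the scaling can be applied to it directly.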

for (int ig = 0; ig < npw; ig++)
{
vq_ptr[ig] = this->Polynomial_Interpolation_nl(
GlobalC::ppcell.tab,
Collaborator

do not use GlobalC

std::vector<FPTYPE> Stress_Func<FPTYPE, Device>::cal_vq_deri(int it, const FPTYPE* gk, int npw)
{
// calculate beta in G-space using an interpolation table
const int nbeta = GlobalC::ucell.atoms[it].ncpp.nbeta;
Collaborator

do not use GlobalC

std::vector<FPTYPE> Stress_Func<FPTYPE, Device>::cal_vq(int it, const FPTYPE* gk, int npw)
{
// calculate beta in G-space using an interpolation table
const int nbeta = GlobalC::ucell.atoms[it].ncpp.nbeta;
Collaborator

do not use GlobalC

gk[ig*3+2] = tmp.z;
FPTYPE norm = sqrt(tmp.norm2());
gk[3 * npw + ig] = norm * GlobalC::ucell.tpiba;
gk[4 * npw + ig] = norm<1e-8?0.0:1.0/norm*GlobalC::ucell.tpiba;
Collaborator

GlobalC is like a cancer.
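The repeated request to avoid `GlobalC` amounts to passing dependencies in explicitly. A minimal sketch, in which `CellParams` and `scaled_norm` are hypothetical names and only `tpiba` comes from the diff:

```cpp
// Hypothetical container for the cell data a routine needs; the name and
// layout are illustrative, not the project's real types.
struct CellParams
{
    double tpiba; // scaling factor named tpiba in the diff
};

// The dependency is an explicit argument instead of a read from GlobalC.
inline double scaled_norm(double norm, const CellParams& cell)
{
    return norm * cell.tpiba;
}
```

A function written this way can be called from a test with a hand-built `CellParams`, without initializing any global state first.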

template <typename FPTYPE, typename Device>
std::vector<FPTYPE> Stress_Func<FPTYPE, Device>::cal_gk(int ik, ModulePW::PW_Basis_K* wfc_basis)
{
int npw = wfc_basis->npwk[ik];
Collaborator

This should be declared const.

ModuleBase::timer::tick("Stress_Func","stress_nl");
}
// cal_gk
template <typename FPTYPE, typename Device>
std::vector<FPTYPE> Stress_Func<FPTYPE, Device>::cal_gk(int ik, ModulePW::PW_Basis_K* wfc_basis)

Does this function need a unit test?
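A sketch of what such a test could check, assuming the packed layout suggested by the indexing in the diff (components first, then `norm * tpiba`, then a guarded inverse term). `pack_gk` is a hypothetical stand-in that takes plain arrays, not the PR's `cal_gk`, and the test uses a plain boolean check rather than any particular test framework:

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for cal_gk. Input: npw vectors as flat (x, y, z)
// triples. Output layout, following the indexing visible in the diff:
//   [0, 3*npw)       copied components
//   [3*npw, 4*npw)   norm * tpiba
//   [4*npw, 5*npw)   0 for a near-zero vector, else 1.0 / norm * tpiba
std::vector<double> pack_gk(const std::vector<double>& g, double tpiba)
{
    const std::size_t npw = g.size() / 3;
    std::vector<double> gk(5 * npw);
    for (std::size_t ig = 0; ig < npw; ++ig)
    {
        const std::size_t i3 = ig * 3;
        gk[i3] = g[i3];
        gk[i3 + 1] = g[i3 + 1];
        gk[i3 + 2] = g[i3 + 2];
        const double norm = std::sqrt(g[i3] * g[i3] + g[i3 + 1] * g[i3 + 1] + g[i3 + 2] * g[i3 + 2]);
        gk[3 * npw + ig] = norm * tpiba;
        gk[4 * npw + ig] = norm < 1e-8 ? 0.0 : 1.0 / norm * tpiba;
    }
    return gk;
}

// Returns true when size, layout, and the zero-vector guard all hold.
bool test_pack_gk()
{
    const std::vector<double> gk = pack_gk({0.0, 0.0, 0.0, 3.0, 4.0, 0.0}, 2.0);
    const double tol = 1e-12;
    return gk.size() == 10
        && gk[3] == 3.0 && gk[4] == 4.0   // components of the second vector
        && gk[6] == 0.0 && gk[8] == 0.0   // zero vector: norm and guarded term
        && std::fabs(gk[7] - 10.0) < tol  // 5 * 2
        && std::fabs(gk[9] - 0.4) < tol;  // 1 / 5 * 2
}
```

The zero vector is included on purpose, since the guarded branch is the part most likely to regress silently.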

}
// cal_vq
template <typename FPTYPE, typename Device>
std::vector<FPTYPE> Stress_Func<FPTYPE, Device>::cal_vq(int it, const FPTYPE* gk, int npw)

Does this function need a unit test?


// cal_vq_deri
template <typename FPTYPE, typename Device>
std::vector<FPTYPE> Stress_Func<FPTYPE, Device>::cal_vq_deri(int it, const FPTYPE* gk, int npw)

Does this function need a unit test?

}
// cal_ylm
template <typename FPTYPE, typename Device>
std::vector<FPTYPE> Stress_Func<FPTYPE, Device>::cal_ylm(int lmax, int npw, const FPTYPE* gk_in)

Does this function need a unit test?

// cal_vkb
// cpu version first, gpu version later
template <typename FPTYPE, typename Device>
void Stress_Func<FPTYPE, Device>::cal_vkb(

Does this function need a unit test?

// cal_vkb_deri
// cpu version first, gpu version later
template <typename FPTYPE, typename Device>
void Stress_Func<FPTYPE, Device>::cal_vkb_deri(

Does this function need a unit test?

@dyzheng
Collaborator Author

dyzheng commented May 6, 2024

I will work with @grysgreat to improve performance on GPU/DCU, so I am changing this PR to a draft.

@dyzheng dyzheng marked this pull request as draft May 6, 2024 05:02
@@ -162,6 +181,23 @@ struct cal_stress_nl_op<FPTYPE, psi::DEVICE_GPU> {
FPTYPE* stress);
};

// cpu version first, gpu version later
template <typename FPTYPE>
struct cal_vkb_op<FPTYPE, psi::DEVICE_GPU>{
Member

struct cal_vkb_op<FPTYPE, psi::DEVICE_GPU>{

Is this function called by any other function?

@dyzheng
Collaborator Author

dyzheng commented Jun 7, 2024

This PR now includes changes merged from @grysgreat and @Religious-J. It is not ready for review yet; I will refactor it later.

Labels
Features Needed The features are indeed needed, and developers should have sophisticated knowledge
Projects
None yet
8 participants