Skip to content

Commit

Permalink
add gpu support for hpsi(npol!=1) (#1490)
Browse files Browse the repository at this point in the history
  • Loading branch information
denghuilu committed Nov 11, 2022
1 parent 0a3d668 commit d4a1242
Show file tree
Hide file tree
Showing 15 changed files with 487 additions and 88 deletions.
30 changes: 30 additions & 0 deletions source/module_hamilt/include/nonlocal.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,21 @@ struct nonlocal_pw_op {
const FPTYPE* deeq,
std::complex<FPTYPE>* ps,
const std::complex<FPTYPE>* becp);

void operator() (
const Device* dev,
const int& l1,
const int& l2,
const int& l3,
int& sum,
int& iat,
const int& nkb,
const int& deeq_x,
const int& deeq_y,
const int& deeq_z,
const std::complex<FPTYPE>* deeq_nc,
std::complex<FPTYPE>* ps,
const std::complex<FPTYPE>* becp);
};

#if __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
Expand All @@ -43,6 +58,21 @@ struct nonlocal_pw_op<FPTYPE, psi::DEVICE_GPU> {
const FPTYPE* deeq,
std::complex<FPTYPE>* ps,
const std::complex<FPTYPE>* becp);

void operator() (
const psi::DEVICE_GPU* dev,
const int& l1,
const int& l2,
const int& l3,
int& sum,
int& iat,
const int& nkb,
const int& deeq_x,
const int& deeq_y,
const int& deeq_z,
const std::complex<FPTYPE>* deeq_nc,
std::complex<FPTYPE>* ps,
const std::complex<FPTYPE>* becp);
};
#endif // __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
} // namespace hamilt
Expand Down
14 changes: 14 additions & 0 deletions source/module_hamilt/include/veff.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@ struct veff_pw_op {
const int& size,
std::complex<FPTYPE>* out,
const FPTYPE* in);

void operator() (
const Device* dev,
const int& size,
std::complex<FPTYPE>* out,
std::complex<FPTYPE>* out1,
const FPTYPE** in);
};

#if __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
Expand All @@ -23,6 +30,13 @@ struct veff_pw_op<FPTYPE, psi::DEVICE_GPU> {
const int& size,
std::complex<FPTYPE>* out,
const FPTYPE* in);

void operator() (
const psi::DEVICE_GPU* dev,
const int& size,
std::complex<FPTYPE>* out,
std::complex<FPTYPE>* out1,
const FPTYPE** in);
};
#endif // __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
} // namespace hamilt
Expand Down
95 changes: 55 additions & 40 deletions source/module_hamilt/ks_pw/nonlocal_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,16 @@ Nonlocal<OperatorPW<FPTYPE, Device>>::Nonlocal(
this->isk = isk_in;
this->ppcell = ppcell_in;
this->ucell = ucell_in;
this->deeq = psi::device::get_device_type<Device>(this->ctx) == psi::GpuDevice ?
this->ppcell->d_deeq : // for GpuDevice
this->ppcell->deeq.ptr; // for CpuDevice
if (psi::device::get_device_type<Device>(this->ctx) == psi::GpuDevice) {
this->deeq = this->ppcell->d_deeq;
this->deeq_nc = this->ppcell->d_deeq_nc;
resize_memory_op()(this->ctx, this->vkb, this->ppcell->vkb.size);
}
else {
this->deeq = this->ppcell->deeq.ptr;
this->deeq_nc = this->ppcell->deeq_nc.ptr;
this->vkb = this->ppcell->vkb.c;
}
if( this->isk == nullptr || this->ppcell == nullptr || this->ucell == nullptr)
{
ModuleBase::WARNING_QUIT("NonlocalPW", "Constuctor of Operator::NonlocalPW is failed, please check your code!");
Expand All @@ -33,6 +40,9 @@ template<typename FPTYPE, typename Device>
Nonlocal<OperatorPW<FPTYPE, Device>>::~Nonlocal() {
delete_memory_op()(this->ctx, this->ps);
delete_memory_op()(this->ctx, this->becp);
if (psi::device::get_device_type<Device>(this->ctx) == psi::GpuDevice) {
delete_memory_op()(this->ctx, this->vkb);
}
}

template<typename FPTYPE, typename Device>
Expand Down Expand Up @@ -108,42 +118,44 @@ void Nonlocal<OperatorPW<FPTYPE, Device>>::add_nonlocal_pp(std::complex<FPTYPE>
}
else
{
#if defined(__CUDA) || defined(__ROCM)
ModuleBase::WARNING_QUIT("NonlocalPW", " gpu implementation of this->npol != 1 is not supported currently !!! ");
#endif
for (int it = 0; it < this->ucell->ntype; it++)
{
int psind = 0;
int becpind = 0;
std::complex<FPTYPE> becp1 = std::complex<FPTYPE>(0.0, 0.0);
std::complex<FPTYPE> becp2 = std::complex<FPTYPE>(0.0, 0.0);

const int nproj = this->ucell->atoms[it].ncpp.nh;
for (int ia = 0; ia < this->ucell->atoms[it].na; ia++)
{
// each atom has nproj, means this is with structure factor;
// each projector (each atom) must multiply coefficient
// with all the other projectors.
for (int ip = 0; ip < nproj; ip++)
{
for (int ip2 = 0; ip2 < nproj; ip2++)
{
for (int ib = 0; ib < m; ib+=2)
{
psind = (sum + ip2) * m + ib;
becpind = ib * nkb + sum + ip;
becp1 = becp[becpind];
becp2 = becp[becpind + nkb];
ps[psind] += this->ppcell->deeq_nc(0, iat, ip2, ip) * becp1
+ this->ppcell->deeq_nc(1, iat, ip2, ip) * becp2;
ps[psind + 1] += this->ppcell->deeq_nc(2, iat, ip2, ip) * becp1
+ this->ppcell->deeq_nc(3, iat, ip2, ip) * becp2;
} // end ib
} // end ih
} // end jh
sum += nproj;
++iat;
} // end na
// added by denghui at 20221109
// <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
nonlocal_op()(
this->ctx, // device context
this->ucell->atoms[it].na, m, nproj, // four loop size
sum, iat, nkb, // additional index params
this->ppcell->deeq_nc.getBound2(), this->ppcell->deeq_nc.getBound3(), this->ppcell->deeq_nc.getBound4(), // realArray operator()
this->deeq_nc, // array of data
this->ps, this->becp); // array of data
// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
// for (int ia = 0; ia < this->ucell->atoms[it].na; ia++)
// {
// // each atom has nproj, means this is with structure factor;
// // each projector (each atom) must multiply coefficient
// // with all the other projectors.
// for (int ib = 0; ib < m; ib+=2)
// {
// for (int ip2 = 0; ip2 < nproj; ip2++)
// {
// for (int ip = 0; ip < nproj; ip++)
// {
// psind = (sum + ip2) * m + ib;
// becpind = ib * nkb + sum + ip;
// becp1 = becp[becpind];
// becp2 = becp[becpind + nkb];
// ps[psind] += this->ppcell->deeq_nc(0, iat, ip2, ip) * becp1
// + this->ppcell->deeq_nc(1, iat, ip2, ip) * becp2;
// ps[psind + 1] += this->ppcell->deeq_nc(2, iat, ip2, ip) * becp1
// + this->ppcell->deeq_nc(3, iat, ip2, ip) * becp2;
// } // end ib
// } // end ih
// } // end jh
// sum += nproj;
// ++iat;
// } // end na
} // end nt
}

Expand All @@ -162,7 +174,7 @@ void Nonlocal<OperatorPW<FPTYPE, Device>>::add_nonlocal_pp(std::complex<FPTYPE>
this->npw,
this->ppcell->nkb,
&ModuleBase::ONE,
this->ppcell->vkb.c,
this->vkb,
this->ppcell->vkb.nc,
this->ps,
inc,
Expand Down Expand Up @@ -194,7 +206,7 @@ void Nonlocal<OperatorPW<FPTYPE, Device>>::add_nonlocal_pp(std::complex<FPTYPE>
npm,
this->ppcell->nkb,
&ModuleBase::ONE,
this->ppcell->vkb.c,
this->vkb,
this->ppcell->vkb.nc,
this->ps,
npm,
Expand Down Expand Up @@ -232,6 +244,9 @@ void Nonlocal<OperatorPW<FPTYPE, Device>>::act
this->max_npw = psi_in->get_nbasis() / psi_in->npol;
this->npol = psi_in->npol;

if (psi::device::get_device_type<Device>(this->ctx) == psi::GpuDevice) {
syncmem_complex_h2d_op()(this->ctx, this->cpu_ctx, this->vkb, this->ppcell->vkb.c, this->ppcell->vkb.size);
}
if (this->ppcell->nkb > 0)
{
//<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
Expand All @@ -251,7 +266,7 @@ void Nonlocal<OperatorPW<FPTYPE, Device>>::act
this->npw,
nkb,
&ModuleBase::ONE,
this->ppcell->vkb.c,
this->vkb,
this->ppcell->vkb.nc,
tmpsi_in,
inc,
Expand Down Expand Up @@ -284,7 +299,7 @@ void Nonlocal<OperatorPW<FPTYPE, Device>>::act
npm,
this->npw,
&ModuleBase::ONE,
this->ppcell->vkb.c,
this->vkb,
this->ppcell->vkb.nc,
tmpsi_in,
this->max_npw,
Expand Down
6 changes: 5 additions & 1 deletion source/module_hamilt/ks_pw/nonlocal_pw.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,21 @@ class Nonlocal<OperatorPW<FPTYPE, Device>> : public OperatorPW<FPTYPE, Device>

const UnitCell* ucell = nullptr;

mutable std::complex<FPTYPE>* becp = nullptr;
mutable std::complex<FPTYPE> *ps = nullptr;
mutable std::complex<FPTYPE> *vkb = nullptr;
mutable std::complex<FPTYPE> *becp = nullptr;
Device* ctx = {};
psi::DEVICE_CPU* cpu_ctx = {};
FPTYPE * deeq = nullptr;
std::complex<FPTYPE> * deeq_nc = nullptr;
// using nonlocal_op = nonlocal_pw_op<FPTYPE, Device>;
using gemv_op = hsolver::gemv_op<FPTYPE, Device>;
using gemm_op = hsolver::gemm_op<FPTYPE, Device>;
using nonlocal_op = nonlocal_pw_op<FPTYPE, Device>;
using set_memory_op = psi::memory::set_memory_op<std::complex<FPTYPE>, Device>;
using resize_memory_op = psi::memory::resize_memory_op<std::complex<FPTYPE>, Device>;
using delete_memory_op = psi::memory::delete_memory_op<std::complex<FPTYPE>, Device>;
using syncmem_complex_h2d_op = psi::memory::synchronize_memory_op<std::complex<FPTYPE>, Device, psi::DEVICE_CPU>;
};

} // namespace hamilt
Expand Down
70 changes: 43 additions & 27 deletions source/module_hamilt/ks_pw/veff_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "module_base/timer.h"
#include "module_base/tool_quit.h"
#include "module_psi/include/device.h"

using hamilt::Veff;
using hamilt::OperatorPW;
Expand All @@ -19,10 +20,13 @@ Veff<OperatorPW<FPTYPE, Device>>::Veff(
this->veff = veff_in[0].c;
//note: "veff = nullptr" means that this core does not treat potential but still treats wf.
this->veff_col = veff_in[0].nc;
this->veff_row = veff_in[0].nr;
this->wfcpw = wfcpw_in;
resize_memory_op()(this->ctx, this->porter, this->wfcpw->nmaxgr);
if (this->npol != 1) {
resize_memory_op()(this->ctx, this->porter1, this->wfcpw->nmaxgr);
this->device = psi::device::get_device_type<Device>(this->ctx);
resize_memory_complex_op()(this->ctx, this->porter, this->wfcpw->nmaxgr);
resize_memory_complex_op()(this->ctx, this->porter1, this->wfcpw->nmaxgr);
if (this->device == psi::GpuDevice) {
resize_memory_double_op()(this->ctx, this->d_veff, this->veff_col * this->veff_row);
}
if (this->isk == nullptr || this->wfcpw == nullptr) {
ModuleBase::WARNING_QUIT("VeffPW", "Constuctor of Operator::VeffPW is failed, please check your code!");
Expand All @@ -32,9 +36,10 @@ Veff<OperatorPW<FPTYPE, Device>>::Veff(
template<typename FPTYPE, typename Device>
Veff<OperatorPW<FPTYPE, Device>>::~Veff()
{
delete_memory_op()(this->ctx, this->porter);
if (this->npol != 1) {
delete_memory_op()(this->ctx, this->porter1);
delete_memory_complex_op()(this->ctx, this->porter);
delete_memory_complex_op()(this->ctx, this->porter1);
if (psi::device::get_device_type<Device>(this->ctx) == psi::GpuDevice) {
delete_memory_double_op()(this->ctx, this->d_veff);
}
}

Expand Down Expand Up @@ -69,7 +74,13 @@ void Veff<OperatorPW<FPTYPE, Device>>::act(
// {
// porter[ir] *= current_veff[ir];
// }
veff_op()(this->ctx, this->veff_col, this->porter, this->veff + current_spin * this->veff_col);
if (this->device == psi::GpuDevice) {
syncmem_double_h2d_op()(this->ctx, this->cpu_ctx, this->d_veff, this->veff, this->veff_col * this->veff_row);
veff_op()(this->ctx, this->veff_col, this->porter, this->d_veff + current_spin * this->veff_col);
}
else {
veff_op()(this->ctx, this->veff_col, this->porter, this->veff + current_spin * this->veff_col);
}
}
// wfcpw->real2recip(porter, tmhpsi, this->ik, true);
wfcpw->real_to_recip(this->ctx, this->porter, tmhpsi, this->ik, true);
Expand All @@ -78,33 +89,38 @@ void Veff<OperatorPW<FPTYPE, Device>>::act(
{
// std::complex<FPTYPE> *porter1 = new std::complex<FPTYPE>[wfcpw->nmaxgr];
// fft to real space and doing things.
wfcpw->recip2real(tmpsi_in, this->porter, this->ik);
wfcpw->recip2real(tmpsi_in + this->max_npw, this->porter1, this->ik);
std::complex<FPTYPE> sup, sdown;
wfcpw->recip_to_real(this->ctx, tmpsi_in, this->porter, this->ik);
wfcpw->recip_to_real(this->ctx, tmpsi_in + this->max_npw, this->porter1, this->ik);
if(this->veff_col != 0)
{
/// denghui added at 20221109
const FPTYPE* current_veff[4];
for(int is=0;is<4;is++)
{
current_veff[is] = this->veff + is * this->veff_col;
if (this->device == psi::GpuDevice) {
syncmem_double_h2d_op()(this->ctx, this->cpu_ctx, this->d_veff, this->veff, this->veff_col * this->veff_row);
}
for (int ir = 0; ir < this->veff_col; ir++)
{
sup = this->porter[ir] * (current_veff[0][ir] + current_veff[3][ir])
+ this->porter1[ir]
* (current_veff[1][ir]
- std::complex<FPTYPE>(0.0, 1.0) * current_veff[2][ir]);
sdown = this->porter1[ir] * (current_veff[0][ir] - current_veff[3][ir])
+ this->porter[ir]
* (current_veff[1][ir]
+ std::complex<FPTYPE>(0.0, 1.0) * current_veff[2][ir]);
this->porter[ir] = sup;
this->porter1[ir] = sdown;
for(int is = 0; is < 4; is++) {
current_veff[is] = this->device == psi::GpuDevice ?
this->d_veff + is * this->veff_col : // for GPU device
this->veff + is * this->veff_col ; // for CPU device
}
veff_op()(this->ctx, this->veff_col, this->porter, this->porter1, current_veff);
// std::complex<FPTYPE> sup, sdown;
// for (int ir = 0; ir < this->veff_col; ir++) {
// sup = this->porter[ir] * (current_veff[0][ir] + current_veff[3][ir])
// + this->porter1[ir]
// * (current_veff[1][ir]
// - std::complex<FPTYPE>(0.0, 1.0) * current_veff[2][ir]);
// sdown = this->porter1[ir] * (current_veff[0][ir] - current_veff[3][ir])
// + this->porter[ir]
// * (current_veff[1][ir]
// + std::complex<FPTYPE>(0.0, 1.0) * current_veff[2][ir]);
// this->porter[ir] = sup;
// this->porter1[ir] = sdown;
// }
}
// (3) fft back to G space.
wfcpw->real2recip(this->porter, tmhpsi, this->ik, true);
wfcpw->real2recip(this->porter1, tmhpsi + this->max_npw, this->ik, true);
wfcpw->real_to_recip(this->ctx, this->porter, tmhpsi, this->ik, true);
wfcpw->real_to_recip(this->ctx, this->porter1, tmhpsi + this->max_npw, this->ik, true);
}
tmhpsi += this->max_npw * this->npol;
tmpsi_in += this->max_npw * this->npol;
Expand Down
13 changes: 10 additions & 3 deletions source/module_hamilt/ks_pw/veff_pw.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,23 @@ class Veff<OperatorPW<FPTYPE, Device>> : public OperatorPW<FPTYPE, Device>
ModulePW::PW_Basis_K* wfcpw = nullptr;

Device* ctx = {};
psi::DEVICE_CPU* cpu_ctx = {};

int veff_col = 0;
int veff_row = 0;
FPTYPE *veff = nullptr;
FPTYPE *d_veff = nullptr;
std::complex<FPTYPE> *porter = nullptr;
std::complex<FPTYPE> *porter1 = nullptr;

psi::AbacusDevice_t device = {};
using veff_op = veff_pw_op<FPTYPE, Device>;

using resize_memory_op = psi::memory::resize_memory_op<std::complex<FPTYPE>, Device>;
using delete_memory_op = psi::memory::delete_memory_op<std::complex<FPTYPE>, Device>;
using resize_memory_double_op = psi::memory::resize_memory_op<FPTYPE, Device>;
using delete_memory_double_op = psi::memory::delete_memory_op<FPTYPE, Device>;

using resize_memory_complex_op = psi::memory::resize_memory_op<std::complex<FPTYPE>, Device>;
using delete_memory_complex_op = psi::memory::delete_memory_op<std::complex<FPTYPE>, Device>;
using syncmem_double_h2d_op = psi::memory::synchronize_memory_op<FPTYPE, Device, psi::DEVICE_CPU>;
};

} // namespace hamilt
Expand Down

0 comments on commit d4a1242

Please sign in to comment.