Skip to content

Commit

Permalink
Feature: Support GPU workflow for CG method (#1502)
Browse files Browse the repository at this point in the history
* Enable GPU support for CG method

* track changes

* fix memory leak

* ready for merge

* address comments
  • Loading branch information
denghuilu committed Nov 14, 2022
1 parent ffa6931 commit de0a130
Show file tree
Hide file tree
Showing 53 changed files with 596 additions and 415 deletions.
2 changes: 1 addition & 1 deletion source/module_esolver/esolver_ks.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ namespace ModuleESolver

hsolver::HSolver* phsol = nullptr;
elecstate::ElecState* pelec = nullptr;
hamilt::Hamilt* p_hamilt = nullptr;
hamilt::Hamilt<double>* p_hamilt = nullptr;
ModulePW::PW_Basis_K* pw_wfc = nullptr;
Charge_Extra CE;

Expand Down
6 changes: 3 additions & 3 deletions source/module_esolver/esolver_ks_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ namespace ModuleESolver
//delete Hamilt
if(this->p_hamilt != nullptr)
{
delete (hamilt::HamiltPW*)this->p_hamilt;
delete (hamilt::HamiltPW<double>*)this->p_hamilt;
this->p_hamilt = nullptr;
}
}
Expand Down Expand Up @@ -198,13 +198,13 @@ namespace ModuleESolver
//delete Hamilt if not first scf
if(this->p_hamilt != nullptr)
{
delete (hamilt::HamiltPW*)this->p_hamilt;
delete (hamilt::HamiltPW<double>*)this->p_hamilt;
this->p_hamilt = nullptr;
}
//allocate HamiltPW
if(this->p_hamilt == nullptr)
{
this->p_hamilt = new hamilt::HamiltPW();
this->p_hamilt = new hamilt::HamiltPW<double>();
}

//----------------------------------------------------------
Expand Down
6 changes: 4 additions & 2 deletions source/module_hamilt/hamilt.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
namespace hamilt
{

template<typename FPTYPE, typename Device = psi::DEVICE_CPU>
class Hamilt
{
public:
Expand All @@ -32,8 +33,9 @@ class Hamilt
int non_first_scf=0;

// first node operator, add operations from each operators
Operator<std::complex<double>>* ops = nullptr;
Operator<double>* opsd = nullptr;
Operator<std::complex<FPTYPE>, Device>* ops = nullptr;
Operator<double, Device>* opsd = nullptr;

};

} // namespace hamilt
Expand Down
2 changes: 1 addition & 1 deletion source/module_hamilt/hamilt_lcao.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ namespace hamilt

// template first for type of k space H matrix elements
// template second for type of temporary matrix, gamma_only fix-gamma-matrix + S-gamma, multi-k fix-Real + S-Real
template <typename T> class HamiltLCAO : public Hamilt
template <typename T> class HamiltLCAO : public Hamilt<double>
{
public:
HamiltLCAO(
Expand Down
99 changes: 84 additions & 15 deletions source/module_hamilt/hamilt_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
#include "module_base/blas_connector.h"
#include "module_base/global_function.h"
#include "module_base/global_variable.h"
#include "module_base/timer.h"
#include "src_parallel/parallel_reduce.h"
#include "src_pw/global.h"

#include "ks_pw/veff_pw.h"
Expand All @@ -15,7 +13,8 @@
namespace hamilt
{

HamiltPW::HamiltPW()
template<typename FPTYPE, typename Device>
HamiltPW<FPTYPE, Device>::HamiltPW()
{
this->classname = "HamiltPW";
const double tpiba2 = GlobalC::ucell.tpiba2;
Expand All @@ -26,7 +25,7 @@ HamiltPW::HamiltPW()
if (GlobalV::T_IN_H)
{
// Operator<double>* ekinetic = new Ekinetic<OperatorLCAO<double>>
Operator<std::complex<double>>* ekinetic = new Ekinetic<OperatorPW<double>>(
Operator<std::complex<FPTYPE>, Device>* ekinetic = new Ekinetic<OperatorPW<FPTYPE, Device>>(
tpiba2,
gk2,
GlobalC::wfcpw->nks,
Expand All @@ -43,7 +42,7 @@ HamiltPW::HamiltPW()
}
if (GlobalV::VL_IN_H)
{
Operator<std::complex<double>>* veff = new Veff<OperatorPW<double>>(
Operator<std::complex<FPTYPE>, Device>* veff = new Veff<OperatorPW<FPTYPE, Device>>(
isk,
&(GlobalC::pot.vr_eff),
GlobalC::wfcpw
Expand All @@ -59,7 +58,7 @@ HamiltPW::HamiltPW()
}
if (GlobalV::VNL_IN_H)
{
Operator<std::complex<double>>* nonlocal = new Nonlocal<OperatorPW<double>>(
Operator<std::complex<FPTYPE>, Device>* nonlocal = new Nonlocal<OperatorPW<FPTYPE, Device>>(
isk,
&GlobalC::ppcell,
&GlobalC::ucell
Expand All @@ -73,7 +72,7 @@ HamiltPW::HamiltPW()
this->ops->add(nonlocal);
}
}
Operator<std::complex<double>>* meta = new Meta<OperatorPW<double>>(
Operator<std::complex<FPTYPE>, Device>* meta = new Meta<OperatorPW<FPTYPE, Device>>(
tpiba,
isk,
&GlobalC::pot.vofk,
Expand All @@ -89,32 +88,102 @@ HamiltPW::HamiltPW()
}
}

HamiltPW::~HamiltPW()
template<typename FPTYPE, typename Device>
HamiltPW<FPTYPE, Device>::~HamiltPW()
{
if(this->ops!= nullptr)
{
delete this->ops;
}
}

void HamiltPW::updateHk(const int ik)
template<typename FPTYPE, typename Device>
void HamiltPW<FPTYPE, Device>::updateHk(const int ik)
{
ModuleBase::TITLE("HamiltPW","updateHk");

this->ops->init(ik);

return;
ModuleBase::TITLE("HamiltPW","updateHk");
}

void HamiltPW::sPsi
template<typename FPTYPE, typename Device>
void HamiltPW<FPTYPE, Device>::sPsi
(
const std::complex<double> *psi,
std::complex<double> *spsi,
size_t size
) const
{
ModuleBase::GlobalFunc::COPYARRAY(psi, spsi, size);
return;
// ModuleBase::GlobalFunc::COPYARRAY(psi, spsi, size);
// denghui replaced at 2022.11.04
syncmem_complex_op()(this->ctx, this->ctx, spsi, psi, size);
}

template<typename FPTYPE, typename Device>
template<typename T_in, typename Device_in>
HamiltPW<FPTYPE, Device>::HamiltPW(const HamiltPW<T_in, Device_in> *hamilt)
{
this->classname = hamilt->classname;
OperatorPW<std::complex<T_in>, Device_in> * node =
reinterpret_cast<OperatorPW<std::complex<T_in>, Device_in> *>(hamilt->ops);

while(node != nullptr) {
if (node->classname == "Ekinetic") {
Operator<std::complex<FPTYPE>, Device>* ekinetic =
new Ekinetic<OperatorPW<FPTYPE, Device>>(
reinterpret_cast<const Ekinetic<OperatorPW<T_in, Device_in>>*>(node));
if(this->ops == nullptr) {
this->ops = ekinetic;
}
else {
this->ops->add(ekinetic);
}
// this->ops = reinterpret_cast<Operator<std::complex<FPTYPE>, Device>*>(node);
}
else if (node->classname == "Nonlocal") {
Operator<std::complex<FPTYPE>, Device>* nonlocal =
new Nonlocal<OperatorPW<FPTYPE, Device>>(
reinterpret_cast<const Nonlocal<OperatorPW<T_in, Device_in>>*>(node));
if(this->ops == nullptr) {
this->ops = nonlocal;
}
else {
this->ops->add(nonlocal);
}
}
else if (node->classname == "Veff") {
Operator<std::complex<FPTYPE>, Device>* veff =
new Veff<OperatorPW<FPTYPE, Device>>(
reinterpret_cast<const Veff<OperatorPW<T_in, Device_in>>*>(node));
if(this->ops == nullptr) {
this->ops = veff;
}
else {
this->ops->add(veff);
}
}
else if (node->classname == "Meta") {
Operator<std::complex<FPTYPE>, Device>* meta =
new Meta<OperatorPW<FPTYPE, Device>>(
reinterpret_cast<const Meta<OperatorPW<T_in, Device_in>>*>(node));
if(this->ops == nullptr) {
this->ops = meta;
}
else {
this->ops->add(meta);
}
}
else {
ModuleBase::WARNING_QUIT("HamiltPW", "Unrecognized Operator type!");
}
node = reinterpret_cast<OperatorPW<std::complex<T_in>, Device_in> *>(node->next_op);
}
}

template class HamiltPW<double, psi::DEVICE_CPU>;
#if ((defined __CUDA) || (defined __ROCM))
template class HamiltPW<double, psi::DEVICE_GPU>;
template HamiltPW<double, psi::DEVICE_CPU>::HamiltPW(const HamiltPW<double, psi::DEVICE_GPU> *hamilt);
template HamiltPW<double, psi::DEVICE_GPU>::HamiltPW(const HamiltPW<double, psi::DEVICE_CPU> *hamilt);
#endif

} // namespace hamilt
8 changes: 7 additions & 1 deletion source/module_hamilt/hamilt_pw.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@
namespace hamilt
{

class HamiltPW : public Hamilt
template<typename FPTYPE, typename Device = psi::DEVICE_CPU>
class HamiltPW : public Hamilt<FPTYPE, Device>
{
public:
HamiltPW();
template<typename T_in, typename Device_in = Device>
explicit HamiltPW(const HamiltPW<T_in, Device_in>* hamilt);
~HamiltPW();

// for target K point, update consequence of hPsi() and matrix()
Expand All @@ -19,6 +22,9 @@ class HamiltPW : public Hamilt
virtual void sPsi(const std::complex<double> *psi_in, std::complex<double> *spsi, const size_t size) const override;

private:

Device *ctx = {};
using syncmem_complex_op = psi::memory::synchronize_memory_op<std::complex<FPTYPE>, Device, Device>;
};

} // namespace hamilt
Expand Down
24 changes: 24 additions & 0 deletions source/module_hamilt/ks_pw/ekinetic_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Ekinetic<OperatorPW<FPTYPE, Device>>::Ekinetic(
const int gk2_row,
const int gk2_col)
{
this->classname = "Ekinetic";
this->cal_type = pw_ekinetic;
this->tpiba2 = tpiba2_in;
this->gk2_row = gk2_row;
Expand Down Expand Up @@ -61,9 +62,32 @@ void Ekinetic<OperatorPW<FPTYPE, Device>>::act(
ModuleBase::timer::tick("Operator", "EkineticPW");
}

// copy construct added by denghui at 20221105
template<typename FPTYPE, typename Device>
template<typename T_in, typename Device_in>
hamilt::Ekinetic<OperatorPW<FPTYPE, Device>>::Ekinetic(const Ekinetic<OperatorPW<T_in, Device_in>> *ekinetic) {
this->classname = "Ekinetic";
this->cal_type = pw_ekinetic;
this->ik = ekinetic->get_ik();
this->tpiba2 = ekinetic->get_tpiba2();
this->gk2_row = ekinetic->get_gk2_row();
this->gk2_col = ekinetic->get_gk2_col();
resize_memory_op()(this->ctx, this->gk2, this->gk2_row * this->gk2_col);
psi::memory::synchronize_memory_op<FPTYPE, Device, Device_in>()(
this->ctx, ekinetic->get_ctx(),
this->gk2, ekinetic->get_gk2(),
this->gk2_row * this->gk2_col);

if( this->tpiba2 < 1e-10 || this->gk2 == nullptr) {
ModuleBase::WARNING_QUIT("EkineticPW", "Copy Constuctor of Operator::EkineticPW is failed, please check your code!");
}
}

namespace hamilt{
template class Ekinetic<OperatorPW<double, psi::DEVICE_CPU>>;
#if ((defined __CUDA) || (defined __ROCM))
template class Ekinetic<OperatorPW<double, psi::DEVICE_GPU>>;
template Ekinetic<OperatorPW<double, psi::DEVICE_CPU>>::Ekinetic(const Ekinetic<OperatorPW<double, psi::DEVICE_GPU>> *ekinetic);
template Ekinetic<OperatorPW<double, psi::DEVICE_GPU>>::Ekinetic(const Ekinetic<OperatorPW<double, psi::DEVICE_CPU>> *ekinetic);
#endif
} // namespace hamilt
11 changes: 10 additions & 1 deletion source/module_hamilt/ks_pw/ekinetic_pw.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ class Ekinetic<OperatorPW<FPTYPE, Device>> : public OperatorPW<FPTYPE, Device>
const int gk2_row,
const int gk2_col);

template<typename T_in, typename Device_in = Device>
explicit Ekinetic(const Ekinetic<OperatorPW<T_in, Device_in>>* ekinetic);

virtual ~Ekinetic();

virtual void act(
Expand All @@ -37,14 +40,20 @@ class Ekinetic<OperatorPW<FPTYPE, Device>> : public OperatorPW<FPTYPE, Device>
const std::complex<FPTYPE>* tmpsi_in,
std::complex<FPTYPE>* tmhpsi)const override;

// denghuilu added for copy construct at 20221105
int get_gk2_row() const {return this->gk2_row;}
int get_gk2_col() const {return this->gk2_col;}
FPTYPE get_tpiba2() const {return this->tpiba2;}
const FPTYPE* get_gk2() const {return this->gk2;}
Device* get_ctx() const {return this->ctx;}

private:

mutable int max_npw = 0;

mutable int npol = 0;

FPTYPE tpiba2 = 0.0;

#if ((defined __CUDA) || (defined __ROCM))
FPTYPE* gk2 = nullptr;
#else
Expand Down
19 changes: 19 additions & 0 deletions source/module_hamilt/ks_pw/meta_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ Meta<OperatorPW<FPTYPE, Device>>::Meta(
ModulePW::PW_Basis_K* wfcpw_in
)
{
this->classname = "Meta";
this->cal_type = pw_meta;
this->isk = isk_in;
this->tpiba = tpiba_in;
Expand Down Expand Up @@ -85,9 +86,27 @@ void Meta<OperatorPW<FPTYPE, Device>>::act(
ModuleBase::timer::tick("Operator", "MetaPW");
}

template<typename FPTYPE, typename Device>
template<typename T_in, typename Device_in>
hamilt::Meta<OperatorPW<FPTYPE, Device>>::Meta(const Meta<OperatorPW<T_in, Device_in>> *meta) {
this->classname = "Meta";
this->cal_type = pw_meta;
this->ik = meta->get_ik();
this->isk = meta->get_isk();
this->tpiba = meta->get_tpiba();
this->vk = meta->get_vk();
this->wfcpw = meta->get_wfcpw();
if(this->isk == nullptr || this->tpiba < 1e-10 || this->vk == nullptr || this->wfcpw == nullptr)
{
ModuleBase::WARNING_QUIT("MetaPW", "Constuctor of Operator::MetaPW is failed, please check your code!");
}
}

namespace hamilt{
template class Meta<OperatorPW<double, psi::DEVICE_CPU>>;
#if ((defined __CUDA) || (defined __ROCM))
template class Meta<OperatorPW<double, psi::DEVICE_GPU>>;
template Meta<OperatorPW<double, psi::DEVICE_CPU>>::Meta(const Meta<OperatorPW<double, psi::DEVICE_GPU>> *meta);
template Meta<OperatorPW<double, psi::DEVICE_GPU>>::Meta(const Meta<OperatorPW<double, psi::DEVICE_CPU>> *meta);
#endif
} // namespace hamilt
9 changes: 9 additions & 0 deletions source/module_hamilt/ks_pw/meta_pw.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ class Meta<OperatorPW<FPTYPE, Device>> : public OperatorPW<FPTYPE, Device>
public:
Meta(FPTYPE tpiba2_in, const int* isk_in, const ModuleBase::matrix* vk, ModulePW::PW_Basis_K* wfcpw);

template<typename T_in, typename Device_in = Device>
explicit Meta(const Meta<OperatorPW<T_in, Device_in>>* meta);

virtual ~Meta(){};

virtual void act(
Expand All @@ -32,6 +35,12 @@ class Meta<OperatorPW<FPTYPE, Device>> : public OperatorPW<FPTYPE, Device>
std::complex<FPTYPE>* tmhpsi
)const override;

// denghui added for copy constructor at 20221105
FPTYPE get_tpiba() const {return this->tpiba;}
const int * get_isk() const {return this->isk;}
const ModuleBase::matrix* get_vk() const {return this->vk;}
ModulePW::PW_Basis_K* get_wfcpw() const {return this->wfcpw;}

private:

mutable int max_npw = 0;
Expand Down

0 comments on commit de0a130

Please sign in to comment.