Skip to content

Commit

Permalink
GPU: Add multi device support for HPsi(veff_pw) (#1456)
Browse files Browse the repository at this point in the history
* add multi device support for hpsi(veff_pw)

* add UTs

* fix compilation errors with cuda environment

* remove cuda flags

* fix CI error

* fix Intel compilation error
  • Loading branch information
denghuilu committed Nov 2, 2022
1 parent d4634c5 commit 3239b99
Show file tree
Hide file tree
Showing 30 changed files with 997 additions and 121 deletions.
2 changes: 1 addition & 1 deletion source/module_base/test/memory_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
class MemoryTest : public testing::Test
{
protected:
// definition according to ../memory.cpp
// definition according to ../memory_psi.cpp
double factor = 1.0 / 1024.0 / 1024.0; // MB
double complex_matrix_mem = 2*sizeof(double) * factor; // byte to MB
double double_mem = sizeof(double) * factor;
Expand Down
21 changes: 18 additions & 3 deletions source/module_elecstate/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ if(ENABLE_LCAO)
../../src_pw/structure_factor.cpp ../../src_pw/pw_complement.cpp
../../src_pw/klist.cpp ../../src_parallel/parallel_kpoints.cpp ../../src_pw/occupy.cpp
)
if(USE_CUDA)
target_link_libraries(EState_updaterhok_pw cufft)
endif()

install(DIRECTORY support DESTINATION ${CMAKE_CURRENT_BINARY_DIR})

Expand All @@ -35,6 +38,9 @@ if(ENABLE_LCAO)
../../src_pdiag/pdiag_common.cpp
../../src_io/output.cpp ../../src_pw/soc.cpp ../../src_io/read_rho.cpp
)
if(USE_CUDA)
target_link_libraries(EState_psiToRho_lcao cufft)
endif()
target_compile_definitions(EState_psiToRho_lcao PRIVATE __MPI)
install(FILES elecstate_lcao_parallel_test.sh DESTINATION ${CMAKE_CURRENT_BINARY_DIR})

Expand Down Expand Up @@ -83,9 +89,7 @@ add_library(
../../module_base/ylm.cpp
)

add_library(
planewave_serial
OBJECT
list(APPEND planewave_serial_srcs
../../module_pw/fft.cpp
../../module_pw/pw_basis.cpp
../../module_pw/pw_basis_k.cpp
Expand All @@ -96,6 +100,17 @@ add_library(
../../module_pw/pw_init.cpp
../../module_pw/pw_transform.cpp
../../module_pw/pw_transform_k.cpp
../../module_pw/src/pw_multi_device.cpp
)

if (USE_CUDA)
list(APPEND planewave_serial_srcs ../../module_pw/src/cuda/pw_multi_device.cu)
endif()

add_library(
planewave_serial
OBJECT
${planewave_serial_srcs}
)

if(ENABLE_COVERAGE)
Expand Down
5 changes: 3 additions & 2 deletions source/module_hamilt/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ list(APPEND objects
hamilt_pw.cpp
src/ekinetic.cpp
src/nonlocal.cpp
src/veff.cpp
ks_pw/ekinetic_pw.cpp
ks_pw/veff_pw.cpp
ks_pw/nonlocal_pw.cpp
Expand All @@ -30,9 +31,9 @@ if(ENABLE_LCAO)
endif()

if (USE_CUDA)
list(APPEND objects src/cuda/ekinetic.cu src/cuda/nonlocal.cu)
list(APPEND objects src/cuda/ekinetic.cu src/cuda/nonlocal.cu src/cuda/veff.cu)
elseif(USE_ROCM)
list(APPEND objects src/rocm/ekinetic.cu src/cuda/nonlocal.cu)
# ROCm branch: take every device source from src/rocm/, matching ks_pw/CMakeLists.txt
# (the previous entry pulled nonlocal.cu from src/cuda/).
list(APPEND objects src/rocm/ekinetic.cu src/rocm/nonlocal.cu src/rocm/veff.cu)
endif()

add_library(
Expand Down
29 changes: 29 additions & 0 deletions source/module_hamilt/include/veff.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#ifndef MODULE_HAMILT_VEFF_H
#define MODULE_HAMILT_VEFF_H

#include "module_psi/psi.h"
#include <complex>

namespace hamilt {
/**
 * @brief Device-dispatched functor used by the plane-wave Veff operator.
 *
 * Only the call operator is declared here; its definition is provided
 * out of line (not visible in this header).
 * NOTE(review): the GPU implementation performs out[i] *= in[i] for
 * i in [0, size); the generic/CPU definition presumably does the same - confirm.
 *
 * @tparam FPTYPE Floating-point type of the real input and of the complex output.
 * @tparam Device Device tag type selecting the implementation.
 *
 * operator() parameters:
 * @param dev   Device context pointer of the selected Device type.
 * @param size  Number of elements to process.
 * @param out   Complex array that is read and updated.
 * @param in    Real array supplying the per-element factors.
 */
template <typename FPTYPE, typename Device>
struct veff_pw_op {
void operator() (
const Device* dev,
const int& size,
std::complex<FPTYPE>* out,
const FPTYPE* in);
};

#if __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
// Partial specialization of the functor for the psi::DEVICE_GPU device tag.
// It is only declared when one of the CUDA/ROCm (or unit-test) macros in the
// enclosing #if is set; the call operator is defined out of line in the
// device source files.
//   dev  : GPU device context pointer.
//   size : number of elements to process.
//   out  : complex array that is read and updated
//          (NOTE(review): expected to be device memory - confirm with callers).
//   in   : real array supplying the per-element factors.
template <typename FPTYPE>
struct veff_pw_op<FPTYPE, psi::DEVICE_GPU> {
void operator() (
const psi::DEVICE_GPU* dev,
const int& size,
std::complex<FPTYPE>* out,
const FPTYPE* in);
};
#endif // __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
} // namespace hamilt
#endif //MODULE_HAMILT_VEFF_H
5 changes: 3 additions & 2 deletions source/module_hamilt/ks_pw/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@ list(APPEND operator_ks_pw_srcs
../operator.cpp
../src/ekinetic.cpp
../src/nonlocal.cpp
../src/veff.cpp
)

if (USE_CUDA)
list(APPEND operator_ks_pw_srcs ../src/cuda/ekinetic.cu ../src/cuda/nonlocal.cu)
list(APPEND operator_ks_pw_srcs ../src/cuda/ekinetic.cu ../src/cuda/nonlocal.cu ../src/cuda/veff.cu)
elseif(USE_ROCM)
list(APPEND operator_ks_pw_srcs ../src/rocm/ekinetic.cu ../src/rocm/nonlocal.cu)
list(APPEND operator_ks_pw_srcs ../src/rocm/ekinetic.cu ../src/rocm/nonlocal.cu ../src/rocm/veff.cu)
endif()

add_library(
Expand Down
9 changes: 8 additions & 1 deletion source/module_hamilt/ks_pw/nonlocal_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "module_base/timer.h"
#include "src_parallel/parallel_reduce.h"
#include "module_base/tool_quit.h"
#include "module_psi/include/device.h"

using hamilt::Nonlocal;
using hamilt::OperatorPW;
Expand All @@ -19,6 +20,9 @@ Nonlocal<OperatorPW<FPTYPE, Device>>::Nonlocal(
this->isk = isk_in;
this->ppcell = ppcell_in;
this->ucell = ucell_in;
this->deeq = psi::device::get_device_type<Device>(this->ctx) == psi::GpuDevice ?
this->ppcell->d_deeq : // for GpuDevice
this->ppcell->deeq.ptr; // for CpuDevice
if( this->isk == nullptr || this->ppcell == nullptr || this->ucell == nullptr)
{
ModuleBase::WARNING_QUIT("NonlocalPW", "Constuctor of Operator::NonlocalPW is failed, please check your code!");
Expand Down Expand Up @@ -78,7 +82,7 @@ void Nonlocal<OperatorPW<FPTYPE, Device>>::add_nonlocal_pp(std::complex<FPTYPE>
this->ucell->atoms[it].na, m, nproj, // four loop size
sum, iat, current_spin, nkb, // additional index params
this->ppcell->deeq.getBound2(), this->ppcell->deeq.getBound3(), this->ppcell->deeq.getBound4(), // realArray operator()
this->ppcell->deeq.ptr, // array of data
this->deeq, // array of data
this->ps, this->becp); // array of data
// <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
// for (int ia = 0; ia < this->ucell->atoms[it].na; ia++)
Expand All @@ -104,6 +108,9 @@ void Nonlocal<OperatorPW<FPTYPE, Device>>::add_nonlocal_pp(std::complex<FPTYPE>
}
else
{
#if defined(__CUDA) || defined(__ROCM)
ModuleBase::WARNING_QUIT("NonlocalPW", " gpu implementation of this->npol != 1 is not supported currently !!! ");
#endif
for (int it = 0; it < this->ucell->ntype; it++)
{
int psind = 0;
Expand Down
1 change: 1 addition & 0 deletions source/module_hamilt/ks_pw/nonlocal_pw.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class Nonlocal<OperatorPW<FPTYPE, Device>> : public OperatorPW<FPTYPE, Device>
mutable std::complex<FPTYPE>* becp = nullptr;
mutable std::complex<FPTYPE> *ps = nullptr;
Device* ctx = {};
FPTYPE * deeq = nullptr;
// using nonlocal_op = nonlocal_pw_op<FPTYPE, Device>;
using gemv_op = hsolver::gemv_op<FPTYPE, Device>;
using gemm_op = hsolver::gemm_op<FPTYPE, Device>;
Expand Down
75 changes: 44 additions & 31 deletions source/module_hamilt/ks_pw/veff_pw.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#include "veff_pw.h"

#include "module_base/timer.h"
#include "src_pw/global.h"
#include "module_base/tool_quit.h"

using hamilt::Veff;
Expand All @@ -15,14 +14,29 @@ Veff<OperatorPW<FPTYPE, Device>>::Veff(
{
this->cal_type = pw_veff;
this->isk = isk_in;
this->veff = veff_in;
// this->veff = veff_in;
// TODO: add an GPU veff array
this->veff = veff_in[0].c;
this->veff_col = veff_in[0].nc;
this->wfcpw = wfcpw_in;
if( this->isk == nullptr || this->veff == nullptr || this->wfcpw == nullptr)
{
resize_memory_op()(this->ctx, this->porter, this->wfcpw->nmaxgr);
if (this->npol != 1) {
resize_memory_op()(this->ctx, this->porter1, this->wfcpw->nmaxgr);
}
if (this->isk == nullptr || this->veff == nullptr || this->wfcpw == nullptr) {
ModuleBase::WARNING_QUIT("VeffPW", "Constuctor of Operator::VeffPW is failed, please check your code!");
}
}

/**
 * @brief Releases the real-space work buffers owned by this operator.
 *
 * `porter` is always released. `porter1` is released whenever it is non-null
 * rather than by testing `npol`: act() reassigns `npol` from the incoming psi,
 * so its value at destruction time may differ from the value that decided
 * whether `porter1` was allocated, which could leak the buffer.
 */
template<typename FPTYPE, typename Device>
Veff<OperatorPW<FPTYPE, Device>>::~Veff()
{
    delete_memory_op()(this->ctx, this->porter);
    // porter1 defaults to nullptr and is only allocated for the non-collinear case.
    if (this->porter1 != nullptr) {
        delete_memory_op()(this->ctx, this->porter1);
    }
}

template<typename FPTYPE, typename Device>
void Veff<OperatorPW<FPTYPE, Device>>::act(
const psi::Psi<std::complex<FPTYPE>, Device> *psi_in,
Expand All @@ -37,65 +51,64 @@ void Veff<OperatorPW<FPTYPE, Device>>::act(
const int current_spin = this->isk[this->ik];
this->npol = psi_in->npol;

std::complex<FPTYPE> *porter = new std::complex<FPTYPE>[wfcpw->nmaxgr];
// std::complex<FPTYPE> *porter = new std::complex<FPTYPE>[wfcpw->nmaxgr];
for (int ib = 0; ib < n_npwx; ib += this->npol)
{
if (this->npol == 1)
{
wfcpw->recip2real(tmpsi_in, porter, this->ik);
// wfcpw->recip2real(tmpsi_in, porter, this->ik);
wfcpw->recip_to_real(this->ctx, tmpsi_in, this->porter, this->ik);
// NOTICE: when MPI threads are larger than number of Z grids
// veff would contain nothing, and nothing should be done in real space
// but the 3DFFT can not be skipped, it will cause hanging
if(this->veff->nc != 0)
if(this->veff_col != 0)
{
const FPTYPE* current_veff = &(this->veff[0](current_spin, 0));
for (int ir = 0; ir < this->veff->nc; ++ir)
{
porter[ir] *= current_veff[ir];
}
// const FPTYPE* current_veff = &(this->veff[0](current_spin, 0));
// for (int ir = 0; ir < this->veff->nc; ++ir)
// {
// porter[ir] *= current_veff[ir];
// }
veff_op()(this->ctx, this->veff_col, this->porter, this->veff + current_spin * this->veff_col);
}
wfcpw->real2recip(porter, tmhpsi, this->ik, true);
// wfcpw->real2recip(porter, tmhpsi, this->ik, true);
wfcpw->real_to_recip(this->ctx, this->porter, tmhpsi, this->ik, true);
}
else
{
std::complex<FPTYPE> *porter1 = new std::complex<FPTYPE>[wfcpw->nmaxgr];
// std::complex<FPTYPE> *porter1 = new std::complex<FPTYPE>[wfcpw->nmaxgr];
// fft to real space and doing things.
wfcpw->recip2real(tmpsi_in, porter, this->ik);
wfcpw->recip2real(tmpsi_in + this->max_npw, porter1, this->ik);
wfcpw->recip2real(tmpsi_in, this->porter, this->ik);
wfcpw->recip2real(tmpsi_in + this->max_npw, this->porter1, this->ik);
std::complex<FPTYPE> sup, sdown;
if(this->veff->nc != 0)
if(this->veff_col != 0)
{
const FPTYPE* current_veff[4];
for(int is=0;is<4;is++)
{
current_veff[is] = &(this->veff[0](is, 0));
current_veff[is] = this->veff + is * this->veff_col;
}
for (int ir = 0; ir < this->veff->nc; ir++)
for (int ir = 0; ir < this->veff_col; ir++)
{
sup = porter[ir] * (current_veff[0][ir] + current_veff[3][ir])
+ porter1[ir]
sup = this->porter[ir] * (current_veff[0][ir] + current_veff[3][ir])
+ this->porter1[ir]
* (current_veff[1][ir]
- std::complex<FPTYPE>(0.0, 1.0) * current_veff[2][ir]);
sdown = porter1[ir] * (current_veff[0][ir] - current_veff[3][ir])
+ porter[ir]
sdown = this->porter1[ir] * (current_veff[0][ir] - current_veff[3][ir])
+ this->porter[ir]
* (current_veff[1][ir]
+ std::complex<FPTYPE>(0.0, 1.0) * current_veff[2][ir]);
porter[ir] = sup;
porter1[ir] = sdown;
this->porter[ir] = sup;
this->porter1[ir] = sdown;
}
}
// (3) fft back to G space.
wfcpw->real2recip(porter, tmhpsi, this->ik, true);
wfcpw->real2recip(porter1, tmhpsi + this->max_npw, this->ik, true);

delete[] porter1;
wfcpw->real2recip(this->porter, tmhpsi, this->ik, true);
wfcpw->real2recip(this->porter1, tmhpsi + this->max_npw, this->ik, true);
}
tmhpsi += this->max_npw * this->npol;
tmpsi_in += this->max_npw * this->npol;
}
delete[] porter;
ModuleBase::timer::tick("Operator", "VeffPW");
return;
}

namespace hamilt{
Expand Down
16 changes: 14 additions & 2 deletions source/module_hamilt/ks_pw/veff_pw.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "operator_pw.h"
#include "module_base/matrix.h"
#include "module_pw/pw_basis_k.h"
#include "module_hamilt/include/veff.h"

namespace hamilt {

Expand All @@ -23,7 +24,7 @@ class Veff<OperatorPW<FPTYPE, Device>> : public OperatorPW<FPTYPE, Device>
public:
Veff(const int* isk_in,const ModuleBase::matrix* veff_in,ModulePW::PW_Basis_K* wfcpw_in);

virtual ~Veff(){};
virtual ~Veff();

virtual void act (
const psi::Psi<std::complex<FPTYPE>, Device> *psi_in,
Expand All @@ -40,9 +41,20 @@ class Veff<OperatorPW<FPTYPE, Device>> : public OperatorPW<FPTYPE, Device>

const int* isk = nullptr;

const ModuleBase::matrix* veff = nullptr;

ModulePW::PW_Basis_K* wfcpw = nullptr;

Device* ctx = {};

int veff_col = 0;
FPTYPE *veff = nullptr;
std::complex<FPTYPE> *porter = nullptr;
std::complex<FPTYPE> *porter1 = nullptr;

using veff_op = veff_pw_op<FPTYPE, Device>;

using resize_memory_op = psi::memory::resize_memory_op<std::complex<FPTYPE>, Device>;
using delete_memory_op = psi::memory::delete_memory_op<std::complex<FPTYPE>, Device>;
};

} // namespace hamilt
Expand Down
43 changes: 43 additions & 0 deletions source/module_hamilt/src/cuda/veff.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#include "module_hamilt/include/veff.h"
#include <complex>
#include <thrust/complex.h>
#include "cuda_runtime.h"

namespace hamilt{

#define THREADS_PER_BLOCK 256

/**
 * @brief Kernel that scales a complex array in place by a real array.
 *
 * Each thread handles at most one element: out[i] *= in[i] for i < size.
 * Threads whose global index falls outside [0, size) do nothing.
 *
 * @param size Number of elements to process.
 * @param out  Complex array, updated in place.
 * @param in   Real array of per-element factors.
 */
template <typename FPTYPE>
__global__ void veff_pw(
    const int size,
    thrust::complex<FPTYPE>* out,
    const FPTYPE* in)
{
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid < size) {
        out[tid] *= in[tid];
    }
}

/**
 * @brief GPU implementation of veff_pw_op: out[i] *= in[i] for i in [0, size).
 *
 * Launches the veff_pw kernel with THREADS_PER_BLOCK threads per block and
 * enough blocks to cover `size` elements. Returns without launching when
 * `size` is not positive.
 *
 * @param dev  GPU device context pointer (not used by this implementation).
 * @param size Number of elements to process.
 * @param out  Complex array, updated in place; reinterpreted as thrust::complex.
 * @param in   Real array of per-element factors.
 */
template <typename FPTYPE>
void veff_pw_op<FPTYPE, psi::DEVICE_GPU>::operator() (
    const psi::DEVICE_GPU* dev,
    const int& size,
    std::complex<FPTYPE>* out,
    const FPTYPE* in)
{
    // A grid of zero blocks is an invalid launch configuration, so an empty
    // (or negative) range must be skipped instead of launched.
    if (size <= 0) {
        return;
    }
    const int block = (size + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
    veff_pw<FPTYPE><<<block, THREADS_PER_BLOCK>>>(
        size, // control params
        reinterpret_cast<thrust::complex<FPTYPE>*>(out), // array of data
        in); // array of data
    // Equivalent host-side loop, kept for reference:
    // for (int ir = 0; ir < size; ++ir)
    // {
    //     out[ir] *= in[ir];
    // }
}

template struct veff_pw_op<double, psi::DEVICE_GPU>;

} // namespace hamilt
5 changes: 1 addition & 4 deletions source/module_hamilt/src/nonlocal.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
#include "module_hamilt/include/nonlocal.h"

#include <iomanip>
#include <iostream>

using namespace hamilt;
using namespace hamilt;

template <typename FPTYPE>
struct hamilt::nonlocal_pw_op<FPTYPE, psi::DEVICE_CPU> {
Expand Down

0 comments on commit 3239b99

Please sign in to comment.