Skip to content

Commit

Permalink
Adding nbnd_occ_kq to arguments of linear solver; allocating wavefunctions up to this integer for evq (psi_k+q)
Browse files Browse the repository at this point in the history
  • Loading branch information
gcistaro committed Apr 30, 2024
1 parent 16071f8 commit eee2fa4
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 34 deletions.
24 changes: 15 additions & 9 deletions src/api/sirius.f90
Original file line number Diff line number Diff line change
Expand Up @@ -5958,12 +5958,13 @@ end subroutine sirius_create_H0
!> @param [in] num_spin_comp Number of spin components.
!> @param [in] alpha_pv Constant for the projector.
!> @param [in] spin Current spin channel.
!> @param [in] nbnd_occ Number of occupied bands.
!> @param [in] nbnd_occ_k Number of occupied bands at k.
!> @param [in] nbnd_occ_kq Number of occupied bands at k+q.
!> @param [in] tol Tolerance for the unconverged residuals (residual L2-norm should be below this value).
!> @param [out] niter Average number of iterations.
!> @param [out] error_code Error code
subroutine sirius_linear_solver(handler,vkq,num_gvec_kq_loc,gvec_kq_loc,dpsi,psi,&
&eigvals,dvpsi,ld,num_spin_comp,alpha_pv,spin,nbnd_occ,tol,niter,error_code)
&eigvals,dvpsi,ld,num_spin_comp,alpha_pv,spin,nbnd_occ_k,nbnd_occ_kq,tol,niter,error_code)
implicit none
!
type(sirius_ground_state_handler), target, intent(in) :: handler
Expand All @@ -5978,7 +5979,8 @@ subroutine sirius_linear_solver(handler,vkq,num_gvec_kq_loc,gvec_kq_loc,dpsi,psi
integer, target, intent(in) :: num_spin_comp
real(8), target, intent(in) :: alpha_pv
integer, target, intent(in) :: spin
integer, target, intent(in) :: nbnd_occ
integer, target, intent(in) :: nbnd_occ_k
integer, target, intent(in) :: nbnd_occ_kq
real(8), optional, target, intent(in) :: tol
integer, optional, target, intent(out) :: niter
integer, optional, target, intent(out) :: error_code
Expand All @@ -5995,14 +5997,15 @@ subroutine sirius_linear_solver(handler,vkq,num_gvec_kq_loc,gvec_kq_loc,dpsi,psi
type(C_PTR) :: num_spin_comp_ptr
type(C_PTR) :: alpha_pv_ptr
type(C_PTR) :: spin_ptr
type(C_PTR) :: nbnd_occ_ptr
type(C_PTR) :: nbnd_occ_k_ptr
type(C_PTR) :: nbnd_occ_kq_ptr
type(C_PTR) :: tol_ptr
type(C_PTR) :: niter_ptr
type(C_PTR) :: error_code_ptr
!
interface
subroutine sirius_linear_solver_aux(handler,vkq,num_gvec_kq_loc,gvec_kq_loc,dpsi,&
&psi,eigvals,dvpsi,ld,num_spin_comp,alpha_pv,spin,nbnd_occ,tol,niter,error_code)&
&psi,eigvals,dvpsi,ld,num_spin_comp,alpha_pv,spin,nbnd_occ_k,nbnd_occ_kq,tol,niter,error_code)&
&bind(C, name="sirius_linear_solver")
use, intrinsic :: ISO_C_BINDING
type(C_PTR), value :: handler
Expand All @@ -6017,7 +6020,8 @@ subroutine sirius_linear_solver_aux(handler,vkq,num_gvec_kq_loc,gvec_kq_loc,dpsi
type(C_PTR), value :: num_spin_comp
type(C_PTR), value :: alpha_pv
type(C_PTR), value :: spin
type(C_PTR), value :: nbnd_occ
type(C_PTR), value :: nbnd_occ_k
type(C_PTR), value :: nbnd_occ_kq
type(C_PTR), value :: tol
type(C_PTR), value :: niter
type(C_PTR), value :: error_code
Expand Down Expand Up @@ -6048,8 +6052,10 @@ subroutine sirius_linear_solver_aux(handler,vkq,num_gvec_kq_loc,gvec_kq_loc,dpsi
alpha_pv_ptr = C_LOC(alpha_pv)
spin_ptr = C_NULL_PTR
spin_ptr = C_LOC(spin)
nbnd_occ_ptr = C_NULL_PTR
nbnd_occ_ptr = C_LOC(nbnd_occ)
nbnd_occ_k_ptr = C_NULL_PTR
nbnd_occ_k_ptr = C_LOC(nbnd_occ_k)
nbnd_occ_kq_ptr = C_NULL_PTR
nbnd_occ_kq_ptr = C_LOC(nbnd_occ_kq)
tol_ptr = C_NULL_PTR
if (present(tol)) then
tol_ptr = C_LOC(tol)
Expand All @@ -6064,7 +6070,7 @@ subroutine sirius_linear_solver_aux(handler,vkq,num_gvec_kq_loc,gvec_kq_loc,dpsi
endif
call sirius_linear_solver_aux(handler_ptr,vkq_ptr,num_gvec_kq_loc_ptr,gvec_kq_loc_ptr,&
&dpsi_ptr,psi_ptr,eigvals_ptr,dvpsi_ptr,ld_ptr,num_spin_comp_ptr,alpha_pv_ptr,spin_ptr,&
&nbnd_occ_ptr,tol_ptr,niter_ptr,error_code_ptr)
&nbnd_occ_k_ptr,nbnd_occ_kq_ptr,tol_ptr,niter_ptr,error_code_ptr)
end subroutine sirius_linear_solver

!
Expand Down
61 changes: 37 additions & 24 deletions src/api/sirius_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6053,10 +6053,14 @@ sirius_create_H0(void* const* handler__, int* error_code__)
type: int
attr: in, required
doc: Current spin channel.
nbnd_occ:
nbnd_occ_k:
type: int
attr: in, required
doc: Number of occupied bands.
doc: Number of occupied bands at k.
nbnd_occ_kq:
type: int
attr: in, required
doc: Number of occupied bands at k+q.
tol:
type: double
attr: in, optional
Expand All @@ -6075,8 +6079,8 @@ void
sirius_linear_solver(void* const* handler__, double const* vkq__, int const* num_gvec_kq_loc__,
int const* gvec_kq_loc__, std::complex<double>* dpsi__, std::complex<double>* psi__,
double* eigvals__, std::complex<double>* dvpsi__, int const* ld__, int const* num_spin_comp__,
double const* alpha_pv__, int const* spin__, int const* nbnd_occ__, double const* tol__,
int* niter__, int* error_code__)
double const* alpha_pv__, int const* spin__, int const* nbnd_occ_k__, int const* nbnd_occ_kq__,
double const* tol__, int* niter__, int* error_code__)
{
using namespace sirius;
PROFILE("sirius_api::sirius_linear_solver");
Expand All @@ -6085,9 +6089,10 @@ sirius_linear_solver(void* const* handler__, double const* vkq__, int const* num
/* works for non-magnetic and collinear cases */
RTE_ASSERT(*num_spin_comp__ == 1);

int nbnd_occ = *nbnd_occ__;
int nbnd_occ_k = *nbnd_occ_k__;
int nbnd_occ_kq = *nbnd_occ_kq__;

if (nbnd_occ == 0) {
if (nbnd_occ_k == 0) {
return;
}

Expand Down Expand Up @@ -6123,29 +6128,36 @@ sirius_linear_solver(void* const* handler__, double const* vkq__, int const* num
auto Hk = H0(kp);

/* copy eigenvalues (factor 2 for rydberg/hartree) */
std::vector<double> eigvals_vec(eigvals__, eigvals__ + nbnd_occ);
std::vector<double> eigvals_vec(eigvals__, eigvals__ + nbnd_occ_k);
for (auto& val : eigvals_vec) {
val /= 2;
}

// Setup dpsi (unknown), psi (part of projector), and dvpsi (right-hand side)
mdarray<std::complex<double>, 3> psi({*ld__, *num_spin_comp__, nbnd_occ}, psi__);
mdarray<std::complex<double>, 3> dpsi({*ld__, *num_spin_comp__, nbnd_occ}, dpsi__);
mdarray<std::complex<double>, 3> dvpsi({*ld__, *num_spin_comp__, nbnd_occ}, dvpsi__);
mdarray<std::complex<double>, 3> psi({*ld__, *num_spin_comp__, nbnd_occ_kq}, psi__);
mdarray<std::complex<double>, 3> dpsi({*ld__, *num_spin_comp__, nbnd_occ_k}, dpsi__);
mdarray<std::complex<double>, 3> dvpsi({*ld__, *num_spin_comp__, nbnd_occ_k}, dvpsi__);

auto dpsi_wf = sirius::wave_function_factory<double>(sctx, kp, wf::num_bands(nbnd_occ),
auto dpsi_wf = sirius::wave_function_factory<double>(sctx, kp, wf::num_bands(nbnd_occ_k),
wf::num_mag_dims(0), false);
auto psi_wf = sirius::wave_function_factory<double>(sctx, kp, wf::num_bands(nbnd_occ),
auto psi_wf = sirius::wave_function_factory<double>(sctx, kp, wf::num_bands(nbnd_occ_kq),
wf::num_mag_dims(0), false);
auto dvpsi_wf = sirius::wave_function_factory<double>(sctx, kp, wf::num_bands(nbnd_occ),
auto dvpsi_wf = sirius::wave_function_factory<double>(sctx, kp, wf::num_bands(nbnd_occ_k),
wf::num_mag_dims(0), false);
auto tmp_wf = sirius::wave_function_factory<double>(sctx, kp, wf::num_bands(nbnd_occ),
auto tmp_wf = sirius::wave_function_factory<double>(sctx, kp, wf::num_bands(nbnd_occ_k),
wf::num_mag_dims(0), false);

for (int ispn = 0; ispn < *num_spin_comp__; ispn++) {
for (int i = 0; i < nbnd_occ; i++) {
for (int i = 0; i < nbnd_occ_kq; i++) {
for (int ig = 0; ig < kp.gkvec().count(); ig++) {
psi_wf->pw_coeffs(ig, wf::spin_index(ispn), wf::band_index(i)) = psi(ig, ispn, i);
}
}
}

for (int ispn = 0; ispn < *num_spin_comp__; ispn++) {
for (int i = 0; i < nbnd_occ_k; i++) {
for (int ig = 0; ig < kp.gkvec().count(); ig++) {
dpsi_wf->pw_coeffs(ig, wf::spin_index(ispn), wf::band_index(i)) = dpsi(ig, ispn, i);
// divide by two to account for hartree / rydberg, this is
// dv * psi and dv should be 2x smaller in sirius.
Expand All @@ -6159,19 +6171,19 @@ sirius_linear_solver(void* const* handler__, double const* vkq__, int const* num
sirius::K_point<double> kp(const_cast<sirius::Simulation_context&>(sctx), gvkq_in, 1.0);
kp.initialize();
auto Hk = H0(kp);
sirius::check_wave_functions<double, std::complex<double>>(
Hk, *psi_wf, sr, wf::band_range(0, nbnd_occ), eigvals_vec.data());
//sirius::check_wave_functions<double, std::complex<double>>(
// Hk, *psi_wf, sr, wf::band_range(0, nbnd_occ_kq), eigvals_vec.data());
}

/* setup auxiliary state vectors for CG */
auto U = sirius::wave_function_factory<double>(sctx, kp, wf::num_bands(nbnd_occ), wf::num_mag_dims(0),
auto U = sirius::wave_function_factory<double>(sctx, kp, wf::num_bands(nbnd_occ_k), wf::num_mag_dims(0),
false);
auto C = sirius::wave_function_factory<double>(sctx, kp, wf::num_bands(nbnd_occ), wf::num_mag_dims(0),
auto C = sirius::wave_function_factory<double>(sctx, kp, wf::num_bands(nbnd_occ_k), wf::num_mag_dims(0),
false);

auto Hphi_wf = sirius::wave_function_factory<double>(sctx, kp, wf::num_bands(nbnd_occ),
auto Hphi_wf = sirius::wave_function_factory<double>(sctx, kp, wf::num_bands(nbnd_occ_k),
wf::num_mag_dims(0), false);
auto Sphi_wf = sirius::wave_function_factory<double>(sctx, kp, wf::num_bands(nbnd_occ),
auto Sphi_wf = sirius::wave_function_factory<double>(sctx, kp, wf::num_bands(nbnd_occ_k),
wf::num_mag_dims(0), false);

auto mem = sctx.processing_unit_memory_t();
Expand All @@ -6192,7 +6204,7 @@ sirius_linear_solver(void* const* handler__, double const* vkq__, int const* num
eigvals_vec, Hphi_wf.get(), Sphi_wf.get(),
psi_wf.get(), tmp_wf.get(),
*alpha_pv__ / 2, // rydberg/hartree factor
wf::band_range(0, nbnd_occ), sr, mem);
wf::band_range(0, nbnd_occ_kq), sr, mem);
/* CG state vectors */
auto X_wrap = sirius::lr::Wave_functions_wrap{dpsi_wf.get(), mem};
auto B_wrap = sirius::lr::Wave_functions_wrap{dvpsi_wf.get(), mem};
Expand All @@ -6211,7 +6223,7 @@ sirius_linear_solver(void* const* handler__, double const* vkq__, int const* num
sirius::lr::Smoothed_diagonal_preconditioner preconditioner{std::move(h_o_diag.first),
std::move(h_o_diag.second),
std::move(eigvals_mdarray),
nbnd_occ,
nbnd_occ_k,
mem,
sr};

Expand All @@ -6230,12 +6242,13 @@ sirius_linear_solver(void* const* handler__, double const* vkq__, int const* num

/* bring wave functions back in order of QE */
for (int ispn = 0; ispn < *num_spin_comp__; ispn++) {
for (int i = 0; i < nbnd_occ; i++) {
for (int i = 0; i < nbnd_occ_k; i++) {
for (int ig = 0; ig < kp.gkvec().count(); ig++) {
dpsi(ig, ispn, i) = dpsi_wf->pw_coeffs(ig, wf::spin_index(ispn), wf::band_index(i));
}
}
}
std::cout << "here i am \n";
},
error_code__);
}
Expand Down
2 changes: 1 addition & 1 deletion src/multi_cg/multi_cg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ struct Linear_response_operator
, br(br)
, sr(sr)
, mem(mem)
, overlap(br.size(), br.size())
, overlap(br.size(), Hphi->num_wf())//br.size())
{
// I think we could just compute alpha_pv here by just making it big enough
// s.t. the operator H - e * S + alpha_pv * Q is positive, e.g:
Expand Down

0 comments on commit eee2fa4

Please sign in to comment.