Skip to content

Commit

Permalink
QM/MM: Improve performance by using standard sum in pw_integral_ab
Browse files Browse the repository at this point in the history
  • Loading branch information
holly-t committed Dec 7, 2021
1 parent ffa6531 commit 83b1487
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 45 deletions.
132 changes: 98 additions & 34 deletions src/pw/pw_methods.F
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ MODULE pw_methods
USE fft_tools, ONLY: BWFFT,&
FWFFT,&
fft3d
USE kahan_sum, ONLY: accurate_dot_product,&
accurate_sum
USE kahan_sum, ONLY: accurate_sum
USE kinds, ONLY: dp
USE machine, ONLY: m_memory
USE message_passing, ONLY: mp_sum
Expand Down Expand Up @@ -73,6 +72,8 @@ MODULE pw_methods

CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'pw_methods'
LOGICAL, PARAMETER, PRIVATE :: debug_this_module = .FALSE.
INTEGER, PARAMETER, PUBLIC :: do_accurate_sum = 0, &
do_standard_sum = 1

INTERFACE pw_gather
MODULE PROCEDURE pw_gather_s, pw_gather_p
Expand Down Expand Up @@ -2069,56 +2070,119 @@ END FUNCTION pw_compatible
!> only returns the real part of it ......
!> \param pw1 ...
!> \param pw2 ...
!> \param sumtype summation method: do_accurate_sum (default when absent) or do_standard_sum (intrinsic SUM/DOT_PRODUCT)
!> \return ...
!> \par History
!> JGH (14-Mar-2001) : Parallel sum and some tests, HALFSPACE case
!> \author apsi
! **************************************************************************************************
FUNCTION pw_integral_ab(pw1, pw2) RESULT(integral_value)
FUNCTION pw_integral_ab(pw1, pw2, sumtype) RESULT(integral_value)

TYPE(pw_type), INTENT(IN) :: pw1, pw2
INTEGER, INTENT(IN), OPTIONAL :: sumtype
REAL(KIND=dp) :: integral_value

CHARACTER(len=*), PARAMETER :: routineN = 'pw_integral_ab'

INTEGER :: handle
INTEGER :: handle, loc_sumtype

CALL timeset(routineN, handle)

loc_sumtype = do_accurate_sum
IF (PRESENT(sumtype)) loc_sumtype = sumtype

IF (pw1%pw_grid%id_nr /= pw2%pw_grid%id_nr) THEN
CPABORT("Grids incompatible")
END IF

! since the return value is real, only do accurate sum on the real bit ?
IF (pw1%in_use == REALDATA3D .AND. pw2%in_use == REALDATA3D) THEN
integral_value = accurate_dot_product(pw1%cr3d(:, :, :), &
pw2%cr3d(:, :, :))
ELSE IF (pw1%in_use == REALDATA3D &
.AND. pw2%in_use == COMPLEXDATA3D) THEN
integral_value = REAL(accurate_sum(pw1%cr3d(:, :, :)* &
pw2%cc3d(:, :, :)), KIND=dp) !? complex bit
ELSE IF (pw1%in_use == COMPLEXDATA3D &
.AND. pw2%in_use == REALDATA3D) THEN
integral_value = REAL(accurate_sum(pw1%cc3d(:, :, :)* &
pw2%cr3d(:, :, :)), KIND=dp) !? complex bit
ELSE IF (pw1%in_use == COMPLEXDATA3D &
.AND. pw2%in_use == COMPLEXDATA3D) THEN
integral_value = REAL(accurate_sum(CONJG(pw1%cc3d(:, :, :)) &
*pw2%cc3d(:, :, :)), KIND=dp) !? complex bit

ELSE IF (pw1%in_use == REALDATA1D &
.AND. pw2%in_use == REALDATA1D) THEN
integral_value = accurate_dot_product(pw1%cr(:), pw2%cr(:))
ELSE IF (pw1%in_use == REALDATA1D &
.AND. pw2%in_use == COMPLEXDATA1D) THEN
integral_value = REAL(accurate_sum(pw1%cr(:)*pw2%cc(:)), KIND=dp) !? complex bit
ELSE IF (pw1%in_use == COMPLEXDATA1D &
.AND. pw2%in_use == REALDATA1D) THEN
integral_value = REAL(accurate_sum(pw1%cc(:)*pw2%cr(:)), KIND=dp) !? complex bit
ELSE IF (pw1%in_use == COMPLEXDATA1D &
.AND. pw2%in_use == COMPLEXDATA1D) THEN
integral_value = REAL(accurate_sum(CONJG(pw1%cc(:))*pw2%cc(:)), KIND=dp) !? complex bit
! do standard sum
IF (loc_sumtype == do_standard_sum) THEN

! plain intrinsic SUM/DOT_PRODUCT here (no accurate_sum); the result is real, so only the real part is kept
IF (pw1%in_use == REALDATA3D .AND. pw2%in_use == REALDATA3D) THEN
!$OMP PARALLEL WORKSHARE DEFAULT(NONE) SHARED(pw1, pw2, integral_value)
integral_value = SUM(pw1%cr3d(:, :, :) &
*pw2%cr3d(:, :, :))
!$OMP END PARALLEL WORKSHARE
ELSE IF (pw1%in_use == REALDATA3D &
.AND. pw2%in_use == COMPLEXDATA3D) THEN
!$OMP PARALLEL WORKSHARE DEFAULT(NONE) SHARED(pw1, pw2, integral_value)
integral_value = REAL(SUM(pw1%cr3d(:, :, :) &
*pw2%cc3d(:, :, :)), KIND=dp) !? complex bit
!$OMP END PARALLEL WORKSHARE
ELSE IF (pw1%in_use == COMPLEXDATA3D &
.AND. pw2%in_use == REALDATA3D) THEN
!$OMP PARALLEL WORKSHARE DEFAULT(NONE) SHARED(pw1, pw2, integral_value)
integral_value = REAL(SUM(pw1%cc3d(:, :, :) &
*pw2%cr3d(:, :, :)), KIND=dp) !? complex bit
!$OMP END PARALLEL WORKSHARE
ELSE IF (pw1%in_use == COMPLEXDATA3D &
.AND. pw2%in_use == COMPLEXDATA3D) THEN
!$OMP PARALLEL WORKSHARE DEFAULT(NONE) SHARED(pw1, pw2, integral_value)
integral_value = REAL(SUM(CONJG(pw1%cc3d(:, :, :)) &
*pw2%cc3d(:, :, :)), KIND=dp) !? complex bit
!$OMP END PARALLEL WORKSHARE

ELSE IF (pw1%in_use == REALDATA1D &
.AND. pw2%in_use == REALDATA1D) THEN
!$OMP PARALLEL WORKSHARE DEFAULT(NONE) SHARED(pw1, pw2, integral_value)
integral_value = DOT_PRODUCT(pw1%cr(:), pw2%cr(:))
!$OMP END PARALLEL WORKSHARE
ELSE IF (pw1%in_use == REALDATA1D &
.AND. pw2%in_use == COMPLEXDATA1D) THEN
!$OMP PARALLEL WORKSHARE DEFAULT(NONE) SHARED(pw1, pw2, integral_value)
integral_value = REAL(DOT_PRODUCT(pw1%cr(:), pw2%cc(:)), KIND=dp) !? complex bit
!$OMP END PARALLEL WORKSHARE
ELSE IF (pw1%in_use == COMPLEXDATA1D &
.AND. pw2%in_use == REALDATA1D) THEN
!$OMP PARALLEL WORKSHARE DEFAULT(NONE) SHARED(pw1, pw2, integral_value)
integral_value = REAL(DOT_PRODUCT(pw1%cc(:), pw2%cr(:)), KIND=dp) !? complex bit
!$OMP END PARALLEL WORKSHARE
ELSE IF (pw1%in_use == COMPLEXDATA1D &
.AND. pw2%in_use == COMPLEXDATA1D) THEN
!$OMP PARALLEL WORKSHARE DEFAULT(NONE) SHARED(pw1, pw2, integral_value)
integral_value = REAL(DOT_PRODUCT(CONJG(pw1%cc(:)), CONJG(pw2%cc(:))), KIND=dp) !? complex bit
!$OMP END PARALLEL WORKSHARE
ELSE
CPABORT("No possible DATA")
END IF

! do accurate sum
ELSE
CPABORT("No possible DATA")

! since the return value is real, only do accurate sum on the real bit ?
IF (pw1%in_use == REALDATA3D .AND. pw2%in_use == REALDATA3D) THEN
integral_value = accurate_sum(pw1%cr3d(:, :, :) &
*pw2%cr3d(:, :, :))
ELSE IF (pw1%in_use == REALDATA3D &
.AND. pw2%in_use == COMPLEXDATA3D) THEN
integral_value = REAL(accurate_sum(pw1%cr3d(:, :, :) &
*pw2%cc3d(:, :, :)), KIND=dp) !? complex bit
ELSE IF (pw1%in_use == COMPLEXDATA3D &
.AND. pw2%in_use == REALDATA3D) THEN
integral_value = REAL(accurate_sum(pw1%cc3d(:, :, :) &
*pw2%cr3d(:, :, :)), KIND=dp) !? complex bit
ELSE IF (pw1%in_use == COMPLEXDATA3D &
.AND. pw2%in_use == COMPLEXDATA3D) THEN
integral_value = REAL(accurate_sum(CONJG(pw1%cc3d(:, :, :)) &
*pw2%cc3d(:, :, :)), KIND=dp) !? complex bit

ELSE IF (pw1%in_use == REALDATA1D &
.AND. pw2%in_use == REALDATA1D) THEN
integral_value = accurate_sum(pw1%cr(:)*pw2%cr(:))
ELSE IF (pw1%in_use == REALDATA1D &
.AND. pw2%in_use == COMPLEXDATA1D) THEN
integral_value = REAL(accurate_sum(pw1%cr(:)*pw2%cc(:)), KIND=dp) !? complex bit
ELSE IF (pw1%in_use == COMPLEXDATA1D &
.AND. pw2%in_use == REALDATA1D) THEN
integral_value = REAL(accurate_sum(pw1%cc(:)*pw2%cr(:)), KIND=dp) !? complex bit
ELSE IF (pw1%in_use == COMPLEXDATA1D &
.AND. pw2%in_use == COMPLEXDATA1D) THEN
integral_value = REAL(accurate_sum(CONJG(pw1%cc(:))*pw2%cc(:)), KIND=dp) !? complex bit
ELSE
CPABORT("No possible DATA")
END IF

END IF

IF (pw1%in_use == REALDATA3D .OR. pw1%in_use == COMPLEXDATA3D) THEN
Expand Down
14 changes: 8 additions & 6 deletions src/pw/pw_spline_utils.F
Original file line number Diff line number Diff line change
Expand Up @@ -2701,11 +2701,12 @@ END SUBROUTINE pw_spline_do_precond
!> \param eps_r the requested precision on the residual
!> \param eps_x the requested precision on the solution
!> \param max_iter maximum number of iteration allowed
!> \param sumtype optional summation method, passed on unchanged to the pw_integral_ab calls
!> \return ...
!> \author fawzi
! **************************************************************************************************
FUNCTION find_coeffs(values, coeffs, linOp, preconditioner, pool, &
eps_r, eps_x, max_iter) RESULT(res)
eps_r, eps_x, max_iter, sumtype) RESULT(res)
TYPE(pw_type), POINTER :: values, coeffs
INTERFACE
! **************************************************************************************************
Expand All @@ -2718,6 +2719,7 @@ END SUBROUTINE linOp
TYPE(pw_pool_type), POINTER :: pool
REAL(kind=dp), INTENT(in) :: eps_r, eps_x
INTEGER, INTENT(in) :: max_iter
INTEGER, INTENT(in), OPTIONAL :: sumtype
LOGICAL :: res

INTEGER :: i, iiter, iter, j, k
Expand Down Expand Up @@ -2745,21 +2747,21 @@ END SUBROUTINE linOp
CALL pw_axpy(values, r)
CALL pw_spline_do_precond(preconditioner, in_v=r, out_v=z)
CALL pw_copy(z, p)
r_z = pw_integral_ab(r, z)
r_z = pw_integral_ab(r, z, sumtype)

DO iter = iiter, MIN(iiter + 9, max_iter)
eps_r_att = SQRT(pw_integral_ab(r, r))
eps_r_att = SQRT(pw_integral_ab(r, r, sumtype))
IF (eps_r_att == 0._dp) THEN
eps_x_att = 0._dp
last = .TRUE.
ELSE
CALL pw_zero(Ap)
CALL linOp(pw_in=p, pw_out=Ap)
alpha = r_z/pw_integral_ab(Ap, p)
alpha = r_z/pw_integral_ab(Ap, p, sumtype)

CALL pw_axpy(p, coeffs, alpha=alpha)

eps_x_att = alpha*SQRT(pw_integral_ab(p, p)) ! try to spare if unneeded?
eps_x_att = alpha*SQRT(pw_integral_ab(p, p, sumtype)) ! try to spare if unneeded?
IF (eps_r_att < eps_r .AND. eps_x_att < eps_x) last = .TRUE.
END IF
!CALL cp_iterate(logger%iter_info,last=last)
Expand All @@ -2772,7 +2774,7 @@ END SUBROUTINE linOp

CALL pw_spline_do_precond(preconditioner, in_v=r, out_v=z)

r_z_new = pw_integral_ab(r, z)
r_z_new = pw_integral_ab(r, z, sumtype)
beta = r_z_new/r_z
r_z = r_z_new

Expand Down
11 changes: 6 additions & 5 deletions src/pw_env/cp_spline_utils.F
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ MODULE cp_spline_utils
USE input_section_types, ONLY: section_vals_type,&
section_vals_val_get
USE kinds, ONLY: dp
USE pw_methods, ONLY: pw_axpy,&
USE pw_methods, ONLY: do_standard_sum,&
pw_axpy,&
pw_zero
USE pw_pool_types, ONLY: pw_pool_create_pw,&
pw_pool_give_back_pw,&
Expand Down Expand Up @@ -125,11 +126,11 @@ SUBROUTINE pw_restrict_s3(pw_fine_in, pw_coarse_out, coarse_pool, param_section)
IF (pbc) THEN
success = find_coeffs(values=values, coeffs=coeffs, &
linOp=spl3_pbc, preconditioner=precond, pool=coarse_pool, &
eps_r=eps_r, eps_x=eps_x, max_iter=max_iter)
eps_r=eps_r, eps_x=eps_x, max_iter=max_iter, sumtype=do_standard_sum)
ELSE
success = find_coeffs(values=values, coeffs=coeffs, &
linOp=spl3_nopbct, preconditioner=precond, pool=coarse_pool, &
eps_r=eps_r, eps_x=eps_x, max_iter=max_iter)
eps_r=eps_r, eps_x=eps_x, max_iter=max_iter, sumtype=do_standard_sum)
END IF
CALL pw_spline_precond_release(precond)

Expand Down Expand Up @@ -200,12 +201,12 @@ SUBROUTINE pw_prolongate_s3(pw_coarse_in, pw_fine_out, coarse_pool, &
success = find_coeffs(values=pw_coarse_in, coeffs=coeffs, &
linOp=spl3_pbc, preconditioner=precond, pool=coarse_pool, &
eps_r=eps_r, eps_x=eps_x, &
max_iter=max_iter)
max_iter=max_iter, sumtype=do_standard_sum)
ELSE
success = find_coeffs(values=pw_coarse_in, coeffs=coeffs, &
linOp=spl3_nopbc, preconditioner=precond, pool=coarse_pool, &
eps_r=eps_r, eps_x=eps_x, &
max_iter=max_iter)
max_iter=max_iter, sumtype=do_standard_sum)
END IF
CPASSERT(success)
CALL pw_spline_precond_release(precond)
Expand Down

0 comments on commit 83b1487

Please sign in to comment.