Skip to content

Commit

Permalink
Fix bad aliasing in dbcsr_multiply
Browse files Browse the repository at this point in the history
  • Loading branch information
abussy committed Aug 25, 2020
1 parent b1360cd commit e3fa6e9
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 50 deletions.
13 changes: 8 additions & 5 deletions src/dm_ls_scf.F
Original file line number Diff line number Diff line change
Expand Up @@ -874,7 +874,7 @@ SUBROUTINE post_scf_homo_lumo(ls_scf_env)
LOGICAL :: converged
REAL(KIND=dp) :: eps_max, eps_min, homo, lumo
TYPE(cp_logger_type), POINTER :: logger
TYPE(dbcsr_type) :: matrix_k, matrix_p
TYPE(dbcsr_type) :: matrix_k, matrix_p, matrix_tmp

CALL timeset(routineN, handle)

Expand All @@ -898,11 +898,13 @@ SUBROUTINE post_scf_homo_lumo(ls_scf_env)

CALL dbcsr_create(matrix_k, template=ls_scf_env%matrix_p(1), matrix_type=dbcsr_type_no_symmetry)

CALL dbcsr_create(matrix_tmp, template=ls_scf_env%matrix_p(1), matrix_type=dbcsr_type_no_symmetry)

DO ispin = 1, nspin
! ortho basis ks
CALL dbcsr_multiply("N", "N", 1.0_dp, ls_scf_env%matrix_s_sqrt_inv, ls_scf_env%matrix_ks(ispin), &
0.0_dp, matrix_k, filter_eps=ls_scf_env%eps_filter)
CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_k, ls_scf_env%matrix_s_sqrt_inv, &
0.0_dp, matrix_tmp, filter_eps=ls_scf_env%eps_filter)
CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_tmp, ls_scf_env%matrix_s_sqrt_inv, &
0.0_dp, matrix_k, filter_eps=ls_scf_env%eps_filter)

! extremal eigenvalues ks
Expand All @@ -911,8 +913,8 @@ SUBROUTINE post_scf_homo_lumo(ls_scf_env)

! ortho basis p
CALL dbcsr_multiply("N", "N", 1.0_dp, ls_scf_env%matrix_s_sqrt, ls_scf_env%matrix_p(ispin), &
0.0_dp, matrix_p, filter_eps=ls_scf_env%eps_filter)
CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_p, ls_scf_env%matrix_s_sqrt, &
0.0_dp, matrix_tmp, filter_eps=ls_scf_env%eps_filter)
CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_tmp, ls_scf_env%matrix_s_sqrt, &
0.0_dp, matrix_p, filter_eps=ls_scf_env%eps_filter)
IF (nspin == 1) CALL dbcsr_scale(matrix_p, 0.5_dp)

Expand All @@ -924,6 +926,7 @@ SUBROUTINE post_scf_homo_lumo(ls_scf_env)

CALL dbcsr_release(matrix_p)
CALL dbcsr_release(matrix_k)
CALL dbcsr_release(matrix_tmp)

CALL timestop(handle)

Expand Down
22 changes: 11 additions & 11 deletions src/dm_ls_scf_methods.F
Original file line number Diff line number Diff line change
Expand Up @@ -582,13 +582,12 @@ SUBROUTINE density_matrix_sign_fixed_mu(matrix_p, trace, mu, sign_method, sign_o
0.0_dp, matrix_tmp, filter_eps=threshold)
CALL dbcsr_add(matrix_tmp, matrix_p_ud, 1.0_dp, -1.0_dp)
frob_matrix = dbcsr_frobenius_norm(matrix_tmp)
CALL dbcsr_release(matrix_tmp)
IF (unit_nr > 0) WRITE (unit_nr, '(T2,A,F20.12)') "Deviation from idempotency: ", frob_matrix

IF (sign_symmetric) THEN
CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_s_sqrt_inv, matrix_p_ud, &
0.0_dp, matrix_p, filter_eps=threshold)
CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_p, matrix_s_sqrt_inv, &
0.0_dp, matrix_tmp, filter_eps=threshold)
CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_tmp, matrix_s_sqrt_inv, &
0.0_dp, matrix_p, filter_eps=threshold)
ELSE

Expand All @@ -597,6 +596,7 @@ SUBROUTINE density_matrix_sign_fixed_mu(matrix_p, trace, mu, sign_method, sign_o
0.0_dp, matrix_p, filter_eps=threshold)
ENDIF
CALL dbcsr_release(matrix_p_ud)
CALL dbcsr_release(matrix_tmp)

CALL timestop(handle)

Expand Down Expand Up @@ -692,14 +692,14 @@ SUBROUTINE density_matrix_sign_internal_mu(matrix_p, trace, mu, sign_method, mat
0.0_dp, matrix_tmp, filter_eps=threshold)
CALL dbcsr_add(matrix_tmp, matrix_p_ud, 1.0_dp, -1.0_dp)
frob_matrix = dbcsr_frobenius_norm(matrix_tmp)
CALL dbcsr_release(matrix_tmp)
IF (unit_nr > 0) WRITE (unit_nr, '(T2,A,F20.12)') "Deviation from idempotency: ", frob_matrix

CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_s_sqrt_inv, matrix_p_ud, &
0.0_dp, matrix_p, filter_eps=threshold)
CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_p, matrix_s_sqrt_inv, &
0.0_dp, matrix_tmp, filter_eps=threshold)
CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_tmp, matrix_s_sqrt_inv, &
0.0_dp, matrix_p, filter_eps=threshold)
CALL dbcsr_release(matrix_p_ud)
CALL dbcsr_release(matrix_tmp)

CALL timestop(handle)

Expand Down Expand Up @@ -1196,14 +1196,14 @@ SUBROUTINE density_matrix_tc2(matrix_p, matrix_ks, matrix_s_sqrt_inv, &
occ_matrix = dbcsr_get_occupation(matrix_x)
IF (unit_nr > 0) WRITE (unit_nr, '(T6,A,I3,1X,1F10.8,1X,1F10.8)') 'Final TC2 iteration ', i, occ_matrix, ABS(nu(i))

CALL dbcsr_release(matrix_xsq)
CALL dbcsr_release(matrix_tmp)

! output to matrix_p, P = inv(S)^0.5 X inv(S)^0.5
CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_x, matrix_s_sqrt_inv, &
0.0_dp, matrix_tmp, filter_eps=threshold)
CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_s_sqrt_inv, matrix_tmp, &
0.0_dp, matrix_p, filter_eps=threshold)
CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_s_sqrt_inv, matrix_p, &
0.0_dp, matrix_p, filter_eps=threshold)

CALL dbcsr_release(matrix_xsq)
CALL dbcsr_release(matrix_tmp)

! ALGO 3 from. SIAM DOI. 10.1137/130911585
X(1) = 1.0_dp
Expand Down
62 changes: 34 additions & 28 deletions src/iterate_matrix.F
Original file line number Diff line number Diff line change
Expand Up @@ -706,10 +706,10 @@ SUBROUTINE matrix_sign_Newton_Schulz(matrix_sign, matrix, threshold, sign_order)
CALL dbcsr_create(tmp1, template=matrix_sign)
CALL dbcsr_create(tmp2, template=matrix_sign)
IF (order .GE. 4) THEN
IF (ABS(order) .GE. 4) THEN
CALL dbcsr_create(tmp3, template=matrix_sign)
ENDIF
IF (order .GE. 7) THEN
IF (ABS(order) .GT. 4) THEN
CALL dbcsr_create(tmp4, template=matrix_sign)
ENDIF
Expand Down Expand Up @@ -819,13 +819,13 @@ SUBROUTINE matrix_sign_Newton_Schulz(matrix_sign, matrix, threshold, sign_order)
CALL dbcsr_add(tmp1, tmp2, 1.0_dp, 48.0_dp/35.0_dp)
CALL dbcsr_multiply("N", "N", 1.0_dp, tmp2, tmp3, 0.0_dp, tmp2, &
CALL dbcsr_multiply("N", "N", 1.0_dp, tmp2, tmp3, 0.0_dp, tmp4, &
filter_eps=threshold, flop=flops)
floptot = floptot + flops
CALL dbcsr_add(tmp1, tmp2, 1.0_dp, 8.0_dp/7.0_dp)
CALL dbcsr_add(tmp1, tmp4, 1.0_dp, 8.0_dp/7.0_dp)
CALL dbcsr_multiply("N", "N", 1.0_dp, tmp2, tmp3, 1.0_dp, tmp1, &
CALL dbcsr_multiply("N", "N", 1.0_dp, tmp4, tmp3, 1.0_dp, tmp1, &
filter_eps=threshold, flop=flops)
floptot = floptot + flops
Expand All @@ -848,17 +848,18 @@ SUBROUTINE matrix_sign_Newton_Schulz(matrix_sign, matrix, threshold, sign_order)
! tmp3=z
CALL dbcsr_copy(tmp3, tmp1)
CALL dbcsr_add_on_diag(tmp3, a0)
CALL dbcsr_multiply("N", "N", 1.0_dp, tmp3, tmp1, 0.0_dp, tmp3, &
CALL dbcsr_multiply("N", "N", 1.0_dp, tmp3, tmp1, 0.0_dp, tmp2, &
filter_eps=threshold, flop=flops)
floptot = floptot + flops
CALL dbcsr_add_on_diag(tmp3, a1)
CALL dbcsr_add_on_diag(tmp2, a1)
CALL dbcsr_add_on_diag(tmp1, a2)
CALL dbcsr_add(tmp1, tmp3, 1.0_dp, 1.0_dp)
CALL dbcsr_multiply("N", "N", 1.0_dp, tmp1, tmp3, 0.0_dp, tmp1, &
CALL dbcsr_add(tmp1, tmp2, 1.0_dp, 1.0_dp)
CALL dbcsr_multiply("N", "N", 1.0_dp, tmp1, tmp2, 0.0_dp, tmp3, &
filter_eps=threshold, flop=flops)
floptot = floptot + flops
CALL dbcsr_add_on_diag(tmp1, a3)
CALL dbcsr_add_on_diag(tmp3, a3)
CALL dbcsr_copy(tmp1, tmp3)
prefactor = 35.0_dp/128.0_dp
ELSE IF (order .EQ. 6) THEN
Expand All @@ -876,13 +877,13 @@ SUBROUTINE matrix_sign_Newton_Schulz(matrix_sign, matrix, threshold, sign_order)
CALL dbcsr_add(tmp1, tmp2, 1.0_dp, 32.0_dp/21.0_dp)
CALL dbcsr_multiply("N", "N", 1.0_dp, tmp2, tmp3, 0.0_dp, tmp2, &
CALL dbcsr_multiply("N", "N", 1.0_dp, tmp2, tmp3, 0.0_dp, tmp4, &
filter_eps=threshold, flop=flops)
floptot = floptot + flops
CALL dbcsr_add(tmp1, tmp2, 1.0_dp, 80.0_dp/63.0_dp)
CALL dbcsr_add(tmp1, tmp4, 1.0_dp, 80.0_dp/63.0_dp)
CALL dbcsr_multiply("N", "N", 1.0_dp, tmp2, tmp3, 0.0_dp, tmp2, &
CALL dbcsr_multiply("N", "N", 1.0_dp, tmp4, tmp3, 0.0_dp, tmp2, &
filter_eps=threshold, flop=flops)
floptot = floptot + flops
Expand Down Expand Up @@ -911,22 +912,22 @@ SUBROUTINE matrix_sign_Newton_Schulz(matrix_sign, matrix, threshold, sign_order)
! tmp3=z
CALL dbcsr_copy(tmp3, tmp1)
CALL dbcsr_add_on_diag(tmp3, a0)
CALL dbcsr_multiply("N", "N", 1.0_dp, tmp3, tmp1, 0.0_dp, tmp3, &
CALL dbcsr_multiply("N", "N", 1.0_dp, tmp3, tmp1, 0.0_dp, tmp2, &
filter_eps=threshold, flop=flops)
floptot = floptot + flops
CALL dbcsr_add_on_diag(tmp3, a1)
CALL dbcsr_add_on_diag(tmp2, a1)
! tmp4=w
CALL dbcsr_copy(tmp4, tmp1)
CALL dbcsr_add_on_diag(tmp4, a2)
CALL dbcsr_multiply("N", "N", 1.0_dp, tmp4, tmp3, 0.0_dp, tmp4, &
CALL dbcsr_multiply("N", "N", 1.0_dp, tmp4, tmp2, 0.0_dp, tmp3, &
filter_eps=threshold, flop=flops)
floptot = floptot + flops
CALL dbcsr_add_on_diag(tmp4, a3)
CALL dbcsr_add_on_diag(tmp3, a3)
CALL dbcsr_add(tmp3, tmp4, 1.0_dp, 1.0_dp)
CALL dbcsr_add_on_diag(tmp3, a4)
CALL dbcsr_multiply("N", "N", 1.0_dp, tmp3, tmp4, 0.0_dp, tmp1, &
CALL dbcsr_add(tmp2, tmp3, 1.0_dp, 1.0_dp)
CALL dbcsr_add_on_diag(tmp2, a4)
CALL dbcsr_multiply("N", "N", 1.0_dp, tmp2, tmp3, 0.0_dp, tmp1, &
filter_eps=threshold, flop=flops)
floptot = floptot + flops
CALL dbcsr_add_on_diag(tmp1, a5)
Expand Down Expand Up @@ -976,10 +977,10 @@ SUBROUTINE matrix_sign_Newton_Schulz(matrix_sign, matrix, threshold, sign_order)
CALL dbcsr_release(tmp1)
CALL dbcsr_release(tmp2)
IF (order .GE. 4) THEN
IF (ABS(order) .GE. 4) THEN
CALL dbcsr_release(tmp3)
ENDIF
IF (order .GE. 7) THEN
IF (ABS(order) .GT. 4) THEN
CALL dbcsr_release(tmp4)
ENDIF
Expand Down Expand Up @@ -1909,7 +1910,7 @@ SUBROUTINE matrix_sqrt_proot(matrix_sqrt, matrix_sqrt_inv, matrix, threshold, or
max_ev, min_ev, occ_matrix, scaling, &
t1, t2
TYPE(cp_logger_type), POINTER :: logger
TYPE(dbcsr_type) :: BK2A, matrixS, Rmat, tmp1, tmp2
TYPE(dbcsr_type) :: BK2A, matrixS, Rmat, tmp1, tmp2, tmp3

CALL cite_reference(Richters2018)

Expand All @@ -1934,6 +1935,7 @@ SUBROUTINE matrix_sqrt_proot(matrix_sqrt, matrix_sqrt_inv, matrix, threshold, or
! for stability symmetry can not be assumed
CALL dbcsr_create(tmp1, template=matrix, matrix_type=dbcsr_type_no_symmetry)
CALL dbcsr_create(tmp2, template=matrix, matrix_type=dbcsr_type_no_symmetry)
CALL dbcsr_create(tmp3, template=matrix, matrix_type=dbcsr_type_no_symmetry)
CALL dbcsr_create(Rmat, template=matrix, matrix_type=dbcsr_type_no_symmetry)
CALL dbcsr_create(matrixS, template=matrix, matrix_type=dbcsr_type_no_symmetry)

Expand Down Expand Up @@ -1989,17 +1991,18 @@ SUBROUTINE matrix_sqrt_proot(matrix_sqrt, matrix_sqrt_inv, matrix, threshold, or
CALL dbcsr_copy(tmp2, Rmat)
ELSE
f = 0
CALL dbcsr_multiply("N", "N", 1.0_dp, tmp2, Rmat, 0.0_dp, tmp2, &
CALL dbcsr_copy(tmp3, tmp2)
CALL dbcsr_multiply("N", "N", 1.0_dp, tmp3, Rmat, 0.0_dp, tmp2, &
filter_eps=threshold, flop=f)
flop3 = flop3 + f
ENDIF
CALL dbcsr_add(tmp1, tmp2, 1.0_dp, 1.0_dp)
ENDDO
ELSE
CALL dbcsr_create(BK2A, template=matrix, matrix_type=dbcsr_type_no_symmetry)
CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_sqrt_inv, matrixS, 0.0_dp, BK2A, &
CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_sqrt_inv, matrixS, 0.0_dp, tmp3, &
filter_eps=threshold, flop=flop1)
CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_sqrt_inv, BK2A, 0.0_dp, BK2A, &
CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_sqrt_inv, tmp3, 0.0_dp, BK2A, &
filter_eps=threshold, flop=flop2)
CALL dbcsr_copy(Rmat, BK2A)
CALL dbcsr_add_on_diag(Rmat, -1.0_dp)
Expand All @@ -2017,15 +2020,17 @@ SUBROUTINE matrix_sqrt_proot(matrix_sqrt, matrix_sqrt_inv, matrix, threshold, or
CALL dbcsr_add(tmp1, tmp2, 1.0_dp, -1.0_dp*(-1)**j*choose)
IF (j .LT. order) THEN
f = 0
CALL dbcsr_multiply("N", "N", 1.0_dp, tmp2, BK2A, 0.0_dp, tmp2, &
CALL dbcsr_copy(tmp3, tmp2)
CALL dbcsr_multiply("N", "N", 1.0_dp, tmp3, BK2A, 0.0_dp, tmp2, &
filter_eps=threshold, flop=f)
flop3 = flop3 + f
ENDIF
ENDDO
CALL dbcsr_release(BK2A)
ENDIF

CALL dbcsr_multiply("N", "N", 0.5_dp, matrix_sqrt_inv, tmp1, 0.0_dp, matrix_sqrt_inv, &
CALL dbcsr_copy(tmp3, matrix_sqrt_inv)
CALL dbcsr_multiply("N", "N", 0.5_dp, tmp3, tmp1, 0.0_dp, matrix_sqrt_inv, &
filter_eps=threshold, flop=flop4)

occ_matrix = dbcsr_get_occupation(matrix_sqrt_inv)
Expand Down Expand Up @@ -2103,6 +2108,7 @@ SUBROUTINE matrix_sqrt_proot(matrix_sqrt, matrix_sqrt_inv, matrix, threshold, or

CALL dbcsr_release(tmp1)
CALL dbcsr_release(tmp2)
CALL dbcsr_release(tmp3)
CALL dbcsr_release(Rmat)
CALL dbcsr_release(matrixS)

Expand Down
2 changes: 1 addition & 1 deletion src/rpa_axk.F
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ SUBROUTINE integrate_exchange(qs_env, dbcsr_Gamma_munu_P, mat_munu, para_env_sub
my_recalc_hfx_integrals = .FALSE.
! One more dbcsr multiplication and trace
CALL dbcsr_multiply("T", "N", 1.0_dp, mat_2d(1, 1)%matrix, dbcsr_Gamma_munu_P(aux)%matrix, &
CALL dbcsr_multiply("T", "N", 1.0_dp, mat_2d(1, 1)%matrix, rho_work_ao(1)%matrix, &
0.0_dp, dbcsr_Gamma_munu_P(aux)%matrix, filter_eps=eps_filter)
CALL dbcsr_trace(dbcsr_Gamma_munu_P(aux)%matrix, e_axk_p)
axk_corr = axk_corr + e_axk_P
Expand Down
12 changes: 7 additions & 5 deletions src/xas_tdp_utils.F
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ SUBROUTINE solve_xas_tdp_prob(donor_state, xas_tdp_control, xas_tdp_env, qs_env,
TYPE(cp_fm_type), POINTER :: c_diff, c_sum, lhs_matrix, lr_coeffs, &
rhs_matrix, work
TYPE(cp_para_env_type), POINTER :: para_env
TYPE(dbcsr_type) :: tmp_mat
TYPE(dbcsr_type) :: tmp_mat, tmp_mat2
TYPE(dbcsr_type), POINTER :: matrix_tdp

CALL timeset(routineN, handle)
Expand Down Expand Up @@ -582,9 +582,10 @@ SUBROUTINE solve_xas_tdp_prob(donor_state, xas_tdp_control, xas_tdp_env, qs_env,
! Need to multiply the current matrix_tdp with the auxiliary matrix
! tmp_mat = (A-D+E)^0.5 * M * (A-D+E)^0.5
CALL dbcsr_create(matrix=tmp_mat, template=matrix_tdp, matrix_type=dbcsr_type_no_symmetry)
CALL dbcsr_create(matrix=tmp_mat2, template=matrix_tdp, matrix_type=dbcsr_type_no_symmetry)
CALL dbcsr_multiply('N', 'N', 1.0_dp, donor_state%matrix_aux, matrix_tdp, &
0.0_dp, tmp_mat, filter_eps=xas_tdp_control%eps_filter)
CALL dbcsr_multiply('N', 'N', 1.0_dp, tmp_mat, donor_state%matrix_aux, &
0.0_dp, tmp_mat2, filter_eps=xas_tdp_control%eps_filter)
CALL dbcsr_multiply('N', 'N', 1.0_dp, tmp_mat2, donor_state%matrix_aux, &
0.0_dp, tmp_mat, filter_eps=xas_tdp_control%eps_filter)

! Get the matrix as a fm
Expand Down Expand Up @@ -620,8 +621,8 @@ SUBROUTINE solve_xas_tdp_prob(donor_state, xas_tdp_control, xas_tdp_env, qs_env,
! Get c_sum = (c^+ + c^-), which appears in all transition density related expressions
! c_sum = -1/omega G^-1 * (A-D+E) * (c^+ - c^-)
CALL dbcsr_multiply('N', 'N', 1.0_dp, donor_state%matrix_aux, donor_state%matrix_aux, &
0.0_dp, tmp_mat, filter_eps=xas_tdp_control%eps_filter)
CALL dbcsr_multiply('N', 'N', 1.0_dp, donor_state%metric(2)%matrix, tmp_mat, &
0.0_dp, tmp_mat2, filter_eps=xas_tdp_control%eps_filter)
CALL dbcsr_multiply('N', 'N', 1.0_dp, donor_state%metric(2)%matrix, tmp_mat2, &
0.0_dp, tmp_mat, filter_eps=xas_tdp_control%eps_filter)
CALL cp_dbcsr_sm_fm_multiply(tmp_mat, c_diff, c_sum, ncol=nrow)
WHERE (tmp_evals .NE. 0) scaling = -1.0_dp/tmp_evals
Expand All @@ -630,6 +631,7 @@ SUBROUTINE solve_xas_tdp_prob(donor_state, xas_tdp_control, xas_tdp_env, qs_env,
! Full TDDFT specific clean-up
CALL cp_fm_release(c_diff)
CALL dbcsr_release(tmp_mat)
CALL dbcsr_release(tmp_mat2)
DEALLOCATE (scaling)

END IF ! TDA
Expand Down

0 comments on commit e3fa6e9

Please sign in to comment.