Skip to content

Commit

Permalink
Use ASSOCIATE block
Browse files Browse the repository at this point in the history
  • Loading branch information
Frederick Stein authored and fstein93 committed Oct 25, 2021
1 parent c4defec commit ebc1209
Showing 1 changed file with 116 additions and 119 deletions.
235 changes: 116 additions & 119 deletions src/mp2_ri_gpw.F
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,7 @@ SUBROUTINE mp2_ri_gpw_compute_en(Emp2_Cou, Emp2_EX, Emp2_S, Emp2_T, BIb_C, mp2_e
REAL(KIND=dp) :: amp_fac, mem_for_aK, mem_for_comm, mem_for_iaK, mem_for_rep, mem_min, &
mem_per_group, mem_real, my_Emp2_Cou, my_Emp2_EX, sym_fac, t_new, t_start
REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :) :: external_ab, external_i_aL, local_ab, &
local_ba, my_local_i_aL, &
my_local_j_aL, t_ab
local_ba, t_ab
REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :) :: BI_C_rec, local_i_aL, local_j_aL, &
Y_i_aP, Y_j_aP
TYPE(cp_para_env_type), POINTER :: para_env_exchange, para_env_P, &
Expand Down Expand Up @@ -369,164 +368,162 @@ SUBROUTINE mp2_ri_gpw_compute_en(Emp2_Cou, Emp2_EX, Emp2_S, Emp2_T, BIb_C, mp2_e
END DO
CALL timestop(handle3)

ALLOCATE (my_local_i_aL(dimen_RI, my_B_size(ispin)), my_local_j_aL(dimen_RI, my_B_size(jspin)))

! loop over the block elements
DO iiB = 1, my_block_size
DO jjB = 1, my_block_size
CALL timeset(routineN//"_expansion", handle3)

my_local_i_aL(:, :) = local_i_aL(:, :, iiB)
my_local_j_aL(:, :) = local_j_aL(:, :, iiB)

! calculate the integrals (ia|jb) strating from my local data ...
local_ab = 0.0_dp
IF ((my_alpha_beta_case) .AND. (calc_forces)) THEN
local_ba = 0.0_dp
END IF
CALL dgemm_counter_start(dgemm_counter)
CALL dgemm('T', 'N', my_B_size(ispin), my_B_size(jspin), dimen_RI, 1.0_dp, &
my_local_i_aL, dimen_RI, my_local_j_aL, dimen_RI, &
0.0_dp, local_ab(my_B_virtual_start(ispin):my_B_virtual_end(ispin), :), my_B_size(ispin))
CALL dgemm_counter_stop(dgemm_counter, my_B_size(ispin), my_B_size(jspin), dimen_RI)
! Additional integrals only for alpha_beta case and forces
IF (my_alpha_beta_case .AND. calc_forces) THEN
local_ba(my_B_virtual_start(jspin):my_B_virtual_end(jspin), :) = &
TRANSPOSE(local_ab(my_B_virtual_start(ispin):my_B_virtual_end(ispin), :))
END IF
! ... and from the other of my subgroup
DO proc_shift = 1, para_env_sub%num_pe - 1
proc_send = sub_proc_map(para_env_sub%mepos + proc_shift)
proc_receive = sub_proc_map(para_env_sub%mepos - proc_shift)

CALL get_group_dist(gd_B_virtual(ispin), proc_receive, rec_B_virtual_start, rec_B_virtual_end, rec_B_size)

ALLOCATE (external_i_aL(dimen_RI, rec_B_size))
external_i_aL = 0.0_dp

CALL mp_sendrecv(my_local_i_aL, proc_send, &
external_i_aL, proc_receive, &
para_env_sub%group)
ASSOCIATE (my_local_i_aL=>local_i_aL(:, :, iiB), my_local_j_aL=>local_j_aL(:, :, jjB))

! calculate the integrals (ia|jb) strating from my local data ...
local_ab = 0.0_dp
IF ((my_alpha_beta_case) .AND. (calc_forces)) THEN
local_ba = 0.0_dp
ENDIF
CALL dgemm_counter_start(dgemm_counter)
CALL dgemm('T', 'N', rec_B_size, my_B_size(jspin), dimen_RI, 1.0_dp, &
external_i_aL, dimen_RI, my_local_j_aL, dimen_RI, &
0.0_dp, local_ab(rec_B_virtual_start:rec_B_virtual_end, 1:my_B_size(jspin)), rec_B_size)
CALL dgemm_counter_stop(dgemm_counter, rec_B_size, my_B_size(jspin), dimen_RI)

DEALLOCATE (external_i_aL)
CALL dgemm('T', 'N', my_B_size(ispin), my_B_size(jspin), dimen_RI, 1.0_dp, &
my_local_i_aL, dimen_RI, my_local_j_aL, dimen_RI, &
0.0_dp, local_ab(my_B_virtual_start(ispin):my_B_virtual_end(ispin), :), my_B_size(ispin))
CALL dgemm_counter_stop(dgemm_counter, my_B_size(ispin), my_B_size(jspin), dimen_RI)
! Additional integrals only for alpha_beta case and forces
IF ((my_alpha_beta_case) .AND. (calc_forces)) THEN
IF (my_alpha_beta_case .AND. calc_forces) THEN
local_ba(my_B_virtual_start(jspin):my_B_virtual_end(jspin), :) = &
TRANSPOSE(local_ab(my_B_virtual_start(ispin):my_B_virtual_end(ispin), :))
ENDIF
! ... and from the other of my subgroup
DO proc_shift = 1, para_env_sub%num_pe - 1
proc_send = sub_proc_map(para_env_sub%mepos + proc_shift)
proc_receive = sub_proc_map(para_env_sub%mepos - proc_shift)

CALL get_group_dist(gd_B_virtual(jspin), proc_receive, rec_B_virtual_start, rec_B_virtual_end, rec_B_size)
CALL get_group_dist(gd_B_virtual(ispin), proc_receive, rec_B_virtual_start, rec_B_virtual_end, rec_B_size)

ALLOCATE (external_i_aL(dimen_RI, rec_B_size))
external_i_aL = 0.0_dp

CALL mp_sendrecv(my_local_j_aL, proc_send, &
CALL mp_sendrecv(my_local_i_aL, proc_send, &
external_i_aL, proc_receive, &
para_env_sub%group)

CALL dgemm_counter_start(dgemm_counter)
CALL dgemm('T', 'N', rec_B_size, my_B_size(ispin), dimen_RI, 1.0_dp, &
external_i_aL, dimen_RI, my_local_i_aL, dimen_RI, &
0.0_dp, local_ba(rec_B_virtual_start:rec_B_virtual_end, 1:my_B_size(ispin)), rec_B_size)
CALL dgemm_counter_stop(dgemm_counter, rec_B_size, my_B_size(ispin), dimen_RI)
CALL dgemm('T', 'N', rec_B_size, my_B_size(jspin), dimen_RI, 1.0_dp, &
external_i_aL, dimen_RI, my_local_j_aL, dimen_RI, &
0.0_dp, local_ab(rec_B_virtual_start:rec_B_virtual_end, 1:my_B_size(jspin)), rec_B_size)
CALL dgemm_counter_stop(dgemm_counter, rec_B_size, my_B_size(jspin), dimen_RI)

DEALLOCATE (external_i_aL)
END IF
! Additional integrals only for alpha_beta case and forces
IF ((my_alpha_beta_case) .AND. (calc_forces)) THEN

END DO
CALL timestop(handle3)

!sample peak memory
CALL m_memory()

CALL timeset(routineN//"_ener", handle3)
! calculate coulomb only MP2
sym_fac = 2.0_dp
IF (my_i == my_j) sym_fac = 1.0_dp
IF (my_alpha_beta_case) sym_fac = 0.5_dp
DO b = 1, my_B_size(jspin)
b_global = b + my_B_virtual_start(jspin) - 1
DO a = 1, virtual(ispin)
my_Emp2_Cou = my_Emp2_Cou - sym_fac*2.0_dp*local_ab(a, b)**2/ &
(Eigenval(homo(ispin) + a, ispin) + Eigenval(homo(jspin) + b_global, jspin) - &
Eigenval(my_i + iiB - 1, ispin) - Eigenval(my_j + jjB - 1, jspin))
END DO
END DO
CALL get_group_dist(gd_B_virtual(jspin), proc_receive, rec_B_virtual_start, &
rec_B_virtual_end, rec_B_size)

IF (calc_ex) THEN
! contract integrals with orbital energies for exchange MP2 energy
! starting with local ...
IF (calc_forces .AND. (.NOT. my_alpha_beta_case)) t_ab = 0.0_dp
DO b = 1, my_B_size(ispin)
b_global = b + my_B_virtual_start(ispin) - 1
DO a = 1, my_B_size(ispin)
a_global = a + my_B_virtual_start(ispin) - 1
my_Emp2_Ex = my_Emp2_Ex + sym_fac*local_ab(a_global, b)*local_ab(b_global, a)/ &
(Eigenval(homo(ispin) + a_global, ispin) + Eigenval(homo(ispin) + b_global, ispin) - &
Eigenval(my_i + iiB - 1, ispin) - Eigenval(my_j + jjB - 1, ispin))
IF (calc_forces .AND. (.NOT. my_alpha_beta_case)) &
t_ab(a_global, b) = -(amp_fac*local_ab(a_global, b) - mp2_env%scale_T*local_ab(b_global, a))/ &
(Eigenval(homo(ispin) + a_global, ispin) + &
Eigenval(homo(ispin) + b_global, ispin) - &
Eigenval(my_i + iiB - 1, ispin) - Eigenval(my_j + jjB - 1, ispin))
END DO
END DO
! ... and then with external data
DO proc_shift = 1, para_env_sub%num_pe - 1
proc_send = sub_proc_map(para_env_sub%mepos + proc_shift)
proc_receive = sub_proc_map(para_env_sub%mepos - proc_shift)
ALLOCATE (external_i_aL(dimen_RI, rec_B_size))
external_i_aL = 0.0_dp

CALL get_group_dist(gd_B_virtual(ispin), proc_receive, rec_B_virtual_start, rec_B_virtual_end, rec_B_size)
CALL get_group_dist(gd_B_virtual(ispin), proc_send, send_B_virtual_start, send_B_virtual_end, send_B_size)
CALL mp_sendrecv(my_local_j_aL, proc_send, &
external_i_aL, proc_receive, &
para_env_sub%group)

ALLOCATE (external_ab(my_B_size(ispin), rec_B_size))
external_ab = 0.0_dp
CALL dgemm_counter_start(dgemm_counter)
CALL dgemm('T', 'N', rec_B_size, my_B_size(ispin), dimen_RI, 1.0_dp, &
external_i_aL, dimen_RI, my_local_i_aL, dimen_RI, &
0.0_dp, local_ba(rec_B_virtual_start:rec_B_virtual_end, 1:my_B_size(ispin)), rec_B_size)
CALL dgemm_counter_stop(dgemm_counter, rec_B_size, my_B_size(ispin), dimen_RI)

CALL mp_sendrecv(local_ab(send_B_virtual_start:send_B_virtual_end, 1:my_B_size(ispin)), proc_send, &
external_ab(1:my_B_size(ispin), 1:rec_B_size), proc_receive, &
para_env_sub%group)
DEALLOCATE (external_i_aL)
ENDIF

END DO
CALL timestop(handle3)

!sample peak memory
CALL m_memory()

CALL timeset(routineN//"_ener", handle3)
! calculate coulomb only MP2
sym_fac = 2.0_dp
IF (my_i == my_j) sym_fac = 1.0_dp
IF (my_alpha_beta_case) sym_fac = 0.5_dp
DO b = 1, my_B_size(jspin)
b_global = b + my_B_virtual_start(jspin) - 1
DO a = 1, virtual(ispin)
my_Emp2_Cou = my_Emp2_Cou - sym_fac*2.0_dp*local_ab(a, b)**2/ &
(Eigenval(homo(ispin) + a, ispin) + Eigenval(homo(jspin) + b_global, jspin) - &
Eigenval(my_i + iiB - 1, ispin) - Eigenval(my_j + jjB - 1, jspin))
END DO
END DO

IF (calc_ex) THEN
! contract integrals with orbital energies for exchange MP2 energy
! starting with local ...
IF (calc_forces .AND. (.NOT. my_alpha_beta_case)) t_ab = 0.0_dp
DO b = 1, my_B_size(ispin)
b_global = b + my_B_virtual_start(ispin) - 1
DO a = 1, rec_B_size
a_global = a + rec_B_virtual_start - 1
my_Emp2_Ex = my_Emp2_Ex + sym_fac*local_ab(a_global, b)*external_ab(b, a)/ &
DO a = 1, my_B_size(ispin)
a_global = a + my_B_virtual_start(ispin) - 1
my_Emp2_Ex = my_Emp2_Ex + sym_fac*local_ab(a_global, b)*local_ab(b_global, a)/ &
(Eigenval(homo(ispin) + a_global, ispin) + Eigenval(homo(ispin) + b_global, ispin) - &
Eigenval(my_i + iiB - 1, ispin) - Eigenval(my_j + jjB - 1, ispin))
IF (calc_forces .AND. (.NOT. my_alpha_beta_case)) &
t_ab(a_global, b) = -(amp_fac*local_ab(a_global, b) - mp2_env%scale_T*external_ab(b, a))/ &
t_ab(a_global, b) = -(amp_fac*local_ab(a_global, b) - mp2_env%scale_T*local_ab(b_global, a))/ &
(Eigenval(homo(ispin) + a_global, ispin) + &
Eigenval(homo(ispin) + b_global, ispin) - &
Eigenval(my_i + iiB - 1, ispin) - Eigenval(my_j + jjB - 1, ispin))
END DO
END DO
! ... and then with external data
DO proc_shift = 1, para_env_sub%num_pe - 1
proc_send = sub_proc_map(para_env_sub%mepos + proc_shift)
proc_receive = sub_proc_map(para_env_sub%mepos - proc_shift)

DEALLOCATE (external_ab)
END DO
END IF
CALL timestop(handle3)

IF (calc_forces) THEN
! update P_ab, Gamma_P_ia
Y_i_aP = 0.0_dp
Y_j_aP = 0.0_dp
CALL mp2_update_P_gamma(mp2_env, para_env_sub, gd_B_virtual, &
Eigenval, homo, dimen_RI, iiB, jjB, my_B_size, &
my_B_virtual_end, my_B_virtual_start, my_i, my_j, virtual, &
sub_proc_map, local_ab, t_ab, my_local_i_aL, my_local_j_aL, &
my_open_shell_ss, Y_i_aP(:, :, iiB), Y_j_aP(:, :, jjB), local_ba, &
ispin, jspin, dgemm_counter)
CALL get_group_dist(gd_B_virtual(ispin), proc_receive, rec_B_virtual_start, rec_B_virtual_end, rec_B_size)
CALL get_group_dist(gd_B_virtual(ispin), proc_send, send_B_virtual_start, send_B_virtual_end, send_B_size)

END IF
ALLOCATE (external_ab(my_B_size(ispin), rec_B_size))
external_ab = 0.0_dp

CALL mp_sendrecv(local_ab(send_B_virtual_start:send_B_virtual_end, 1:my_B_size(ispin)), proc_send, &
external_ab(1:my_B_size(ispin), 1:rec_B_size), proc_receive, &
para_env_sub%group)

DO b = 1, my_B_size(ispin)
b_global = b + my_B_virtual_start(ispin) - 1
DO a = 1, rec_B_size
a_global = a + rec_B_virtual_start - 1
my_Emp2_Ex = my_Emp2_Ex + sym_fac*local_ab(a_global, b)*external_ab(b, a)/ &
(Eigenval(homo(ispin) + a_global, ispin) + Eigenval(homo(ispin) + b_global, ispin) - &
Eigenval(my_i + iiB - 1, ispin) - Eigenval(my_j + jjB - 1, ispin))
IF (calc_forces .AND. (.NOT. my_alpha_beta_case)) &
t_ab(a_global, b) = -(amp_fac*local_ab(a_global, b) - mp2_env%scale_T*external_ab(b, a))/ &
(Eigenval(homo(ispin) + a_global, ispin) + &
Eigenval(homo(ispin) + b_global, ispin) - &
Eigenval(my_i + iiB - 1, ispin) - Eigenval(my_j + jjB - 1, ispin))
END DO
END DO

DEALLOCATE (external_ab)
END DO
END IF
CALL timestop(handle3)

IF (calc_forces) THEN
! update P_ab, Gamma_P_ia
Y_i_aP = 0.0_dp
Y_j_aP = 0.0_dp
CALL mp2_update_P_gamma(mp2_env, para_env_sub, gd_B_virtual, &
Eigenval, homo, dimen_RI, iiB, jjB, my_B_size, &
my_B_virtual_end, my_B_virtual_start, my_i, my_j, virtual, &
sub_proc_map, local_ab, t_ab, my_local_i_aL, my_local_j_aL, &
my_open_shell_ss, Y_i_aP(:, :, iiB), Y_j_aP(:, :, jjB), local_ba, &
ispin, jspin, dgemm_counter)

END IF

END ASSOCIATE

END DO ! jjB
END DO ! iiB

DEALLOCATE (my_local_i_aL, my_local_j_aL)

ELSE
CALL timeset(routineN//"_comm", handle3)
! No work to do and we know that we have to receive nothing, but send something
Expand Down

0 comments on commit ebc1209

Please sign in to comment.