Skip to content

Commit

Permalink
Low-scaling| minor optimization of RI-HFX/post-HF tensor code
Browse files Browse the repository at this point in the history
  • Loading branch information
abussy committed Jan 11, 2023
1 parent b3f583d commit 719a8cd
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 40 deletions.
13 changes: 7 additions & 6 deletions src/hfx_ri.F
Original file line number Diff line number Diff line change
Expand Up @@ -781,11 +781,11 @@ SUBROUTINE hfx_ri_pre_scf_Pmat(qs_env, ri_data)
DO i_mem = 1, ri_data%n_mem_RI
bounds_i(:, 1) = [ri_data%starts_array_RI_mem(i_mem), ri_data%ends_array_RI_mem(i_mem)]

CALL dbt_batched_contract_init(ri_data%t_2c_int(1, 1))
DO j_mem = 1, ri_data%n_mem
bounds_j(:, 1) = [ri_data%starts_array_mem(j_mem), ri_data%ends_array_mem(j_mem)]
bounds_j(:, 2) = [1, dims_3c(3)]
CALL timeset(routineN//"_RIx3C", handle2)
CALL dbt_batched_contract_init(ri_data%t_2c_int(1, 1))
CALL dbt_contract(1.0_dp, ri_data%t_2c_int(1, 1), ri_data%t_3c_int_ctr_3(1, 1), &
0.0_dp, t_3c_2, &
contract_1=[2], notcontract_1=[1], &
Expand All @@ -795,7 +795,6 @@ SUBROUTINE hfx_ri_pre_scf_Pmat(qs_env, ri_data)
bounds_3=bounds_j, &
unit_nr=unit_nr_dbcsr, &
flop=nflop)
CALL dbt_batched_contract_finalize(ri_data%t_2c_int(1, 1))

ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
CALL timestop(handle2)
Expand All @@ -811,6 +810,7 @@ SUBROUTINE hfx_ri_pre_scf_Pmat(qs_env, ri_data)

CALL timestop(handle2)
END DO
CALL dbt_batched_contract_finalize(ri_data%t_2c_int(1, 1))
END DO
CALL dbt_batched_contract_finalize(t_3c_2)
CALL dbt_batched_contract_finalize(ri_data%t_3c_int_ctr_3(1, 1))
Expand Down Expand Up @@ -2648,6 +2648,8 @@ SUBROUTINE hfx_ri_forces_Pmat(qs_env, ri_data, nspins, hf_fraction, rho_ao, rho_
CALL timeset(routineN//"_3c", handle)
!Start looping of the batches
CALL dbt_batched_contract_init(t_SVS)
CALL dbt_batched_contract_init(t_R)
DO i_mem = 1, n_mem
ibounds(:, 1) = [batch_start(i_mem), batch_end(i_mem)]
Expand Down Expand Up @@ -2689,30 +2691,27 @@ SUBROUTINE hfx_ri_forces_Pmat(qs_env, ri_data, nspins, hf_fraction, rho_ao, rho_
CALL dbt_copy(t_3c_sparse, t_3c_4, bounds=bounds_cpy)
!Contract with the 2-center product S^-1 * V * S^-1 while keeping sparsity of derivatives
CALL dbt_batched_contract_init(t_SVS)
CALL dbt_contract(1.0_dp, t_SVS, t_3c_3, 0.0_dp, t_3c_4, &
contract_1=[2], notcontract_1=[1], &
contract_2=[1], notcontract_2=[2, 3], &
map_1=[1], map_2=[2, 3], filter_eps=ri_data%filter_eps, &
retain_sparsity=.TRUE., unit_nr=unit_nr_dbcsr, flop=nflop)
ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
CALL dbt_batched_contract_finalize(t_SVS)
CALL dbt_copy(t_3c_4, t_3c_5, summation=.TRUE., move_data=.TRUE.)
ijbounds(:, 1) = ibounds(:, 1)
ijbounds(:, 2) = jbounds(:, 1)
!Contract R_PS = (acP) M_acS
CALL dbt_batched_contract_init(t_R)
CALL dbt_contract(1.0_dp, t_3c_int_2, t_3c_3, 1.0_dp, t_R, &
contract_1=[2, 3], notcontract_1=[1], &
contract_2=[2, 3], notcontract_2=[1], &
map_1=[1], map_2=[2], filter_eps=ri_data%filter_eps, &
bounds_1=ijbounds, bounds_3=kbounds, &
unit_nr=unit_nr_dbcsr, flop=nflop)
ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
CALL dbt_batched_contract_finalize(t_R)
END DO !k_mem
END DO !j_mem
Expand Down Expand Up @@ -2770,6 +2769,8 @@ SUBROUTINE hfx_ri_forces_Pmat(qs_env, ri_data, nspins, hf_fraction, rho_ao, rho_
CALL dbt_clear(t_3c_help_1)
CALL dbt_clear(t_3c_help_2)
END DO !i_mem
CALL dbt_batched_contract_finalize(t_SVS)
CALL dbt_batched_contract_finalize(t_R)
CALL timestop(handle)
CALL timeset(routineN//"_2c", handle)
Expand Down
46 changes: 15 additions & 31 deletions src/rpa_im_time_force_methods.F
Original file line number Diff line number Diff line change
Expand Up @@ -531,23 +531,19 @@ SUBROUTINE init_im_time_forces(force_data, fm_matrix_PQ, t_3c_M, unit_nr, mp2_en
!Create the rest of the 2-center AO tensors
nspins = dft_control%nspins
ALLOCATE (force_data%P_virt(nspins), force_data%P_occ(nspins))
ALLOCATE (force_data%sum_YP_tau(nspins), force_data%sum_O_tau(nspins), force_data%sum_ker_tau(nspins))
ALLOCATE (force_data%sum_YP_tau(nspins), force_data%sum_O_tau(nspins))
DO ispin = 1, nspins
ALLOCATE (force_data%P_virt(ispin)%matrix, force_data%P_occ(ispin)%matrix)
ALLOCATE (force_data%sum_YP_tau(ispin)%matrix, force_data%sum_O_tau(ispin)%matrix)
ALLOCATE (force_data%sum_ker_tau(ispin)%matrix)
CALL dbcsr_create(force_data%P_virt(ispin)%matrix, template=matrix_s(1)%matrix)
CALL dbcsr_create(force_data%P_occ(ispin)%matrix, template=matrix_s(1)%matrix)
CALL dbcsr_create(force_data%sum_O_tau(ispin)%matrix, template=matrix_s(1)%matrix)
CALL dbcsr_create(force_data%sum_ker_tau(ispin)%matrix, template=matrix_s(1)%matrix)
CALL dbcsr_create(force_data%sum_YP_tau(ispin)%matrix, template=matrix_s(1)%matrix)

CALL dbcsr_copy(force_data%sum_O_tau(ispin)%matrix, matrix_s(1)%matrix)
CALL dbcsr_copy(force_data%sum_ker_tau(ispin)%matrix, matrix_s(1)%matrix)
CALL dbcsr_copy(force_data%sum_YP_tau(ispin)%matrix, matrix_s(1)%matrix)

CALL dbcsr_set(force_data%sum_O_tau(ispin)%matrix, 0.0_dp)
CALL dbcsr_set(force_data%sum_ker_tau(ispin)%matrix, 0.0_dp)
CALL dbcsr_set(force_data%sum_YP_tau(ispin)%matrix, 0.0_dp)
END DO

Expand Down Expand Up @@ -953,12 +949,6 @@ SUBROUTINE calc_laplace_loop_forces(force_data, mat_P_omega, t_3c_M, t_3c_O, t_3
CALL dbcsr_multiply('N', 'N', pref*tau, matrix_ks(Pspin)%matrix, Y_2, 1.0_dp, &
force_data%sum_O_tau(Pspin)%matrix, retain_sparsity=.TRUE.)

!save the following to apply the kernel to it later
CALL dbcsr_multiply('N', 'N', fac*tau, Y_1, force_data%P_occ(Pspin)%matrix, 1.0_dp, &
force_data%sum_ker_tau(Pspin)%matrix, retain_sparsity=.TRUE.)
CALL dbcsr_multiply('N', 'N', -fac*tau, Y_2, force_data%P_virt(Pspin)%matrix, 1.0_dp, &
force_data%sum_ker_tau(Pspin)%matrix, retain_sparsity=.TRUE.)

CALL timestop(handle2)

IF (use_virial) THEN
Expand Down Expand Up @@ -1524,12 +1514,6 @@ SUBROUTINE calc_rpa_loop_forces(force_data, mat_P_omega, t_3c_M, t_3c_O, t_3c_O_
CALL dbcsr_multiply('N', 'N', tau*spin_fac, dbcsr_work_symm, Y_2, 1.0_dp, &
force_data%sum_O_tau(ispin)%matrix, retain_sparsity=.TRUE.)

!save the following to apply the kernel to it later
CALL dbcsr_multiply('N', 'N', fac*tau, Y_1, force_data%P_occ(ispin)%matrix, 1.0_dp, &
force_data%sum_ker_tau(ispin)%matrix, retain_sparsity=.TRUE.)
CALL dbcsr_multiply('N', 'N', -fac*tau, Y_2, force_data%P_virt(ispin)%matrix, 1.0_dp, &
force_data%sum_ker_tau(ispin)%matrix, retain_sparsity=.TRUE.)

CALL timestop(handle2)

IF (use_virial) THEN
Expand Down Expand Up @@ -1926,6 +1910,9 @@ SUBROUTINE perform_3c_ops(force, work_virial, t_R_occ, t_R_virt, force_data, fac
CALL dbt_copy(t_3c_M, t_3c_ints)
CALL timestop(handle2)

CALL dbt_batched_contract_init(t_R_occ)
CALL dbt_batched_contract_init(t_R_virt)
CALL dbt_batched_contract_init(t_KBKT)
DO i_mem = 1, cut_memory
ibounds(:, 1) = [starts_array_mc(i_mem), ends_array_mc(i_mem)]

Expand Down Expand Up @@ -1959,23 +1946,19 @@ SUBROUTINE perform_3c_ops(force, work_virial, t_R_occ, t_R_virt, force_data, fac
mp2_env%ri_rpa_im_time%eps_compress)
CALL dbt_copy(t_3c_M, t_3c_3, move_data=.TRUE.)

CALL dbt_batched_contract_init(t_R_occ)
CALL dbt_contract(1.0_dp, t_M_occ, t_3c_3, 1.0_dp, t_R_occ, &
contract_1=[1, 2], notcontract_1=[3], &
contract_2=[1, 2], notcontract_2=[3], &
map_1=[1], map_2=[2], filter_eps=eps_filter, &
flop=flop, unit_nr=unit_nr_dbcsr)
dbcsr_nflop = dbcsr_nflop + flop
CALL dbt_batched_contract_finalize(t_R_occ)

CALL dbt_batched_contract_init(t_R_virt)
CALL dbt_contract(1.0_dp, t_M_virt, t_3c_3, 1.0_dp, t_R_virt, &
contract_1=[1, 2], notcontract_1=[3], &
contract_2=[1, 2], notcontract_2=[3], &
map_1=[1], map_2=[2], filter_eps=eps_filter, &
flop=flop, unit_nr=unit_nr_dbcsr)
dbcsr_nflop = dbcsr_nflop + flop
CALL dbt_batched_contract_finalize(t_R_virt)
END DO
CALL dbt_copy(t_3c_M, t_3c_3)
CALL dbt_copy(t_3c_M, t_M_virt)
Expand All @@ -1991,6 +1974,7 @@ SUBROUTINE perform_3c_ops(force, work_virial, t_R_occ, t_R_virt, force_data, fac
bounds_cpy(:, 3) = [starts_array_mc(j_mem), ends_array_mc(j_mem)]
CALL dbt_copy(t_3c_sparse, t_3c_7, bounds=bounds_cpy)

CALL dbt_batched_contract_init(t_dm_virt)
DO k_mem = 1, n_mem_RI
bounds_2c(:, 1) = [batch_start_RI(k_mem), batch_end_RI(k_mem)]
bounds_2c(:, 2) = [starts_array_mc(i_mem), ends_array_mc(i_mem)]
Expand All @@ -1999,14 +1983,12 @@ SUBROUTINE perform_3c_ops(force, work_virial, t_R_occ, t_R_virt, force_data, fac

!Calculate (mu nu| P) * D_occ * D_virt
!Note: technically need M_occ*D_virt + M_virt*D_occ, but it is equivalent to 2*M_occ*D_virt
CALL dbt_batched_contract_init(t_dm_virt)
CALL dbt_contract(2.0_dp, t_3c_4, t_dm_virt, 0.0_dp, t_3c_5, &
contract_1=[3], notcontract_1=[1, 2], &
contract_2=[1], notcontract_2=[2], &
map_1=[1, 2], map_2=[3], filter_eps=eps_filter, &
bounds_2=bounds_2c, bounds_3=jbounds, flop=flop, unit_nr=unit_nr_dbcsr)
dbcsr_nflop = dbcsr_nflop + flop
CALL dbt_batched_contract_finalize(t_dm_virt)

CALL get_tensor_occupancy(t_3c_5, nze, occ)
nze_ddint = nze_ddint + nze
Expand All @@ -2017,18 +1999,17 @@ SUBROUTINE perform_3c_ops(force, work_virial, t_R_occ, t_R_virt, force_data, fac

!Calculate the contraction of the above with K*B*K^T
CALL timeset(routineN//"_3c_KBK", handle2)
CALL dbt_batched_contract_init(t_KBKT)
CALL dbt_contract(1.0_dp, t_KBKT, t_3c_6, 0.0_dp, t_3c_7, &
contract_1=[2], notcontract_1=[1], &
contract_2=[1], notcontract_2=[2, 3], &
map_1=[1], map_2=[2, 3], filter_eps=eps_filter, &
retain_sparsity=.TRUE., flop=flop, unit_nr=unit_nr_dbcsr)
dbcsr_nflop = dbcsr_nflop + flop
CALL dbt_batched_contract_finalize(t_KBKT)
CALL timestop(handle2)
CALL dbt_copy(t_3c_7, t_3c_8, summation=.TRUE.)

END DO !k_mem
CALL dbt_batched_contract_finalize(t_dm_virt)
END DO !j_mem

CALL dbt_copy(t_3c_8, t_3c_help_1, move_data=.TRUE.)
Expand Down Expand Up @@ -2083,6 +2064,9 @@ SUBROUTINE perform_3c_ops(force, work_virial, t_R_occ, t_R_virt, force_data, fac
END IF
CALL dbt_clear(t_3c_help_2)
END DO !i_mem
CALL dbt_batched_contract_finalize(t_KBKT)
CALL dbt_batched_contract_finalize(t_R_occ)
CALL dbt_batched_contract_finalize(t_R_virt)

DO k_mem = 1, n_mem_RI
DO i_mem = 1, cut_memory
Expand Down Expand Up @@ -2287,7 +2271,7 @@ END SUBROUTINE calc_post_loop_forces

! **************************************************************************************************
!> \brief Prepares the RHS of the z-vector equation. Apply the xc and HFX kernel on the previously
!> stored sum_ker_tau density, and add it to the final force_data%sum_O_tau quantity
!> stored sum_YP_tau density, and add it to the final force_data%sum_O_tau quantity
!> \param force_data ...
!> \param qs_env ...
! **************************************************************************************************
Expand Down Expand Up @@ -2334,7 +2318,7 @@ SUBROUTINE prepare_for_response(force_data, qs_env)
CALL dbcsr_set(dbcsr_p_work(ispin)%matrix, 0.0_dp)
END DO

!Apply the kernel on the density saved in force_data%sum_ker_tau
!Apply the kernel on the density saved in force_data%sum_YP_tau
ALLOCATE (rhoz_r(nspins), rhoz_g(nspins))
DO ispin = 1, nspins
CALL pw_pool_create_pw(auxbas_pw_pool, rhoz_r(ispin), &
Expand All @@ -2351,7 +2335,7 @@ SUBROUTINE prepare_for_response(force_data, qs_env)

CALL pw_zero(rhoz_tot_gspace)
DO ispin = 1, nspins
CALL calculate_rho_elec(ks_env=qs_env%ks_env, matrix_p=force_data%sum_ker_tau(ispin)%matrix, &
CALL calculate_rho_elec(ks_env=qs_env%ks_env, matrix_p=force_data%sum_YP_tau(ispin)%matrix, &
rho=rhoz_r(ispin), rho_gspace=rhoz_g(ispin))
CALL pw_axpy(rhoz_g(ispin), rhoz_tot_gspace)
END DO
Expand All @@ -2373,7 +2357,7 @@ SUBROUTINE prepare_for_response(force_data, qs_env)
CALL pw_pool_create_pw(auxbas_pw_pool, tauz_r(ispin), &
use_data=REALDATA3D, in_space=REALSPACE)

CALL calculate_rho_elec(ks_env=qs_env%ks_env, matrix_p=force_data%sum_ker_tau(ispin)%matrix, &
CALL calculate_rho_elec(ks_env=qs_env%ks_env, matrix_p=force_data%sum_YP_tau(ispin)%matrix, &
rho=tauz_r(ispin), rho_gspace=tauz_g, compute_tau=.TRUE.)
END DO
CALL pw_pool_give_back_pw(auxbas_pw_pool, tauz_g)
Expand Down Expand Up @@ -2445,7 +2429,7 @@ SUBROUTINE prepare_for_response(force_data, qs_env)
nao = admm_env%nao_orb
nao_aux = admm_env%nao_aux_fit
DO ispin = 1, nspins
CALL copy_dbcsr_to_fm(force_data%sum_ker_tau(ispin)%matrix, admm_env%work_orb_orb)
CALL copy_dbcsr_to_fm(force_data%sum_YP_tau(ispin)%matrix, admm_env%work_orb_orb)
CALL parallel_gemm('N', 'N', nao_aux, nao, nao, 1.0_dp, admm_env%A, admm_env%work_orb_orb, &
0.0_dp, admm_env%work_aux_orb)
CALL parallel_gemm('N', 'T', nao_aux, nao_aux, nao, 1.0_dp, admm_env%work_aux_orb, admm_env%A, &
Expand Down Expand Up @@ -2546,7 +2530,7 @@ SUBROUTINE prepare_for_response(force_data, qs_env)
CALL dbcsr_release(dbcsr_work)
CALL dbcsr_deallocate_matrix_set(ker_tau_admm)
ELSE
CALL tddft_hfx_matrix(dbcsr_p_work, force_data%sum_ker_tau, qs_env, .FALSE., .FALSE.)
CALL tddft_hfx_matrix(dbcsr_p_work, force_data%sum_YP_tau, qs_env, .FALSE., .FALSE.)
END IF
END IF

Expand Down
5 changes: 2 additions & 3 deletions src/rpa_im_time_force_types.F
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ MODULE rpa_im_time_force_types
!The occupied and virtual density matrices (standard block size, one for each spin)
TYPE(dbcsr_p_type), DIMENSION(:), POINTER :: P_occ, P_virt

!The weighted sum of the O(tau) matrices for thre response and associated kernel densities
TYPE(dbcsr_p_type), DIMENSION(:), POINTER :: sum_O_tau, sum_ker_tau
!The weighted sum of the O(tau) matrices for thre response
TYPE(dbcsr_p_type), DIMENSION(:), POINTER :: sum_O_tau

!The weigthed sum of the YP matrices for the trace with the Fockian derivative
TYPE(dbcsr_p_type), DIMENSION(:), POINTER :: sum_YP_tau
Expand Down Expand Up @@ -88,7 +88,6 @@ SUBROUTINE im_time_force_release(force_data)
CALL dbcsr_deallocate_matrix_set(force_data%P_virt)
CALL dbcsr_deallocate_matrix_set(force_data%P_occ)
CALL dbcsr_deallocate_matrix_set(force_data%sum_O_tau)
CALL dbcsr_deallocate_matrix_set(force_data%sum_ker_tau)
CALL dbcsr_deallocate_matrix_set(force_data%sum_YP_tau)

DO i_xyz = 1, 3
Expand Down

0 comments on commit 719a8cd

Please sign in to comment.