Skip to content

Commit

Permalink
RI HFX: adopt recent optimizations from low-scaling RPA
Browse files: browse the repository at this point in the history
  • Loading branch information
pseewald committed Mar 7, 2020
1 parent c715faa commit fc268b6
Show file tree
Hide file tree
Showing 7 changed files with 297 additions and 184 deletions.
115 changes: 95 additions & 20 deletions src/hfx_ri.F
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

MODULE hfx_ri

USE atomic_kind_types, ONLY: atomic_kind_type
USE basis_set_types, ONLY: gto_basis_set_p_type,&
gto_basis_set_type
USE cp_array_utils, ONLY: cp_1d_r_p_type
Expand All @@ -30,8 +31,7 @@ MODULE hfx_ri
dbcsr_add, dbcsr_add_on_diag, dbcsr_copy, dbcsr_create, dbcsr_distribution_get, &
dbcsr_distribution_release, dbcsr_distribution_type, dbcsr_dot, dbcsr_filter, &
dbcsr_frobenius_norm, dbcsr_get_info, dbcsr_multiply, dbcsr_p_type, dbcsr_release, &
dbcsr_scalar, dbcsr_scale, dbcsr_type, dbcsr_type_no_symmetry, dbcsr_type_real_8, &
dbcsr_type_symmetric
dbcsr_scalar, dbcsr_scale, dbcsr_type, dbcsr_type_no_symmetry, dbcsr_type_symmetric
USE dbcsr_tensor_api, ONLY: &
dbcsr_t_batched_contract_finalize, dbcsr_t_batched_contract_init, dbcsr_t_clear, &
dbcsr_t_contract, dbcsr_t_copy, dbcsr_t_copy_matrix_to_tensor, &
Expand All @@ -48,7 +48,8 @@ MODULE hfx_ri
matrix_sqrt_newton_schulz
USE kinds, ONLY: default_string_length,&
dp
USE message_passing, ONLY: mp_environ
USE message_passing, ONLY: mp_cart_create,&
mp_environ
USE particle_methods, ONLY: get_particle_set
USE particle_types, ONLY: particle_type
USE qs_environment_types, ONLY: get_qs_env,&
Expand Down Expand Up @@ -77,6 +78,9 @@ MODULE hfx_ri
USE qs_tensors_types, ONLY: contiguous_tensor_dist,&
create_2c_tensor,&
create_3c_tensor,&
create_tensor_batches,&
distribution_3d_create,&
distribution_3d_type,&
neighbor_list_3c_type,&
split_block_sizes
#include "./base/base_uses.f90"
Expand Down Expand Up @@ -324,12 +328,20 @@ SUBROUTINE hfx_ri_pre_scf_calc_tensors(qs_env, ri_data, t_2c_int_RI, t_2c_int_po
CHARACTER(LEN=*), PARAMETER :: routineN = 'hfx_ri_pre_scf_calc_tensors', &
routineP = moduleN//':'//routineN

INTEGER :: handle, ibasis, natom, nkind
INTEGER, ALLOCATABLE, DIMENSION(:) :: sizes_AO, sizes_RI
INTEGER :: handle, i_mem, ibasis, mp_comm_t3c, &
n_mem, natom, nkind
INTEGER, ALLOCATABLE, DIMENSION(:) :: dist_AO_1, dist_AO_2, dist_RI, &
ends_array_mc_block_int, ends_array_mc_int, sizes_AO, sizes_RI, &
starts_array_mc_block_int, starts_array_mc_int
INTEGER, DIMENSION(3) :: pcoord, pdims
INTEGER, DIMENSION(:), POINTER :: col_bsize, row_bsize
TYPE(atomic_kind_type), DIMENSION(:), POINTER :: atomic_kind_set
TYPE(dbcsr_distribution_type) :: dbcsr_dist
TYPE(dbcsr_t_type) :: t_3c_tmp
TYPE(dbcsr_t_type), DIMENSION(1, 1) :: t_3c_int_batched
TYPE(dft_control_type), POINTER :: dft_control
TYPE(distribution_2d_type), POINTER :: dist_2d
TYPE(distribution_3d_type) :: dist_3d
TYPE(gto_basis_set_p_type), ALLOCATABLE, &
DIMENSION(:), TARGET :: basis_set_AO, basis_set_RI
TYPE(gto_basis_set_type), POINTER :: orb_basis
Expand Down Expand Up @@ -363,19 +375,61 @@ SUBROUTINE hfx_ri_pre_scf_calc_tensors(qs_env, ri_data, t_2c_int_RI, t_2c_int_po
CALL init_interaction_radii_orb_basis(orb_basis, ri_data%eps_pgf_orb)
ENDDO

CALL dbcsr_t_create(t_3c_int(1, 1), "(RI | AO AO)", ri_data%dist, [1], [2, 3], &
dbcsr_type_real_8, sizes_RI, sizes_AO, sizes_AO)
n_mem = ri_data%n_mem
CALL create_tensor_batches(sizes_AO, n_mem, starts_array_mc_int, ends_array_mc_int, &
starts_array_mc_block_int, ends_array_mc_block_int)
DEALLOCATE (starts_array_mc_int, ends_array_mc_int)

CALL build_3c_neighbor_lists(nl_3c, basis_set_RI, basis_set_AO, basis_set_AO, ri_data%dist_3d, ri_data%ri_metric, &
CALL create_3c_tensor(t_3c_int_batched(1, 1), dist_RI, dist_AO_1, dist_AO_2, ri_data%pgrid, &
sizes_RI, sizes_AO, sizes_AO, map1=[1], map2=[2, 3], &
starts_array_block_2=starts_array_mc_block_int, ends_array_block_2=ends_array_mc_block_int, &
name="(RI | AO AO)")

CALL get_qs_env(qs_env, nkind=nkind, particle_set=particle_set, atomic_kind_set=atomic_kind_set)
CALL dbcsr_t_mp_environ_pgrid(ri_data%pgrid, pdims, pcoord)
CALL mp_cart_create(ri_data%pgrid%mp_comm_2d, 3, pdims, pcoord, mp_comm_t3c)
CALL distribution_3d_create(dist_3d, dist_RI, dist_AO_1, dist_AO_2, &
nkind, particle_set, mp_comm_t3c, own_comm=.TRUE.)
DEALLOCATE (dist_RI, dist_AO_1, dist_AO_2)

CALL create_3c_tensor(t_3c_int(1, 1), dist_RI, dist_AO_1, dist_AO_2, ri_data%pgrid, &
ri_data%bsizes_RI_split, ri_data%bsizes_AO_split, ri_data%bsizes_AO_split, &
map1=[1], map2=[2, 3], &
name="O (RI AO | AO)")

! create 3c tensor for storage of ints

CALL build_3c_neighbor_lists(nl_3c, basis_set_RI, basis_set_AO, basis_set_AO, dist_3d, ri_data%ri_metric, &
"HFX_3c_nl", &
qs_env, op_pos=1, sym_jk=.TRUE.)
qs_env, op_pos=1, sym_jk=.TRUE., own_dist=.TRUE.)

DO i_mem = 1, n_mem
CALL build_3c_integrals(t_3c_int_batched, ri_data%filter_eps_3c/2, qs_env, nl_3c, &
basis_set_RI, basis_set_AO, basis_set_AO, &
ri_data%ri_metric, int_eps=ri_data%eps_schwarz, op_pos=1, &
desymmetrize=.FALSE., &
bounds_j=[starts_array_mc_block_int(i_mem), ends_array_mc_block_int(i_mem)])
CALL dbcsr_t_copy(t_3c_int_batched(1, 1), t_3c_int(1, 1), summation=.TRUE., move_data=.TRUE.)
CALL dbcsr_t_clear(t_3c_int_batched(1, 1))
CALL dbcsr_t_filter(t_3c_int(1, 1), ri_data%filter_eps_3c/2)
ENDDO

CALL build_3c_integrals(t_3c_int, ri_data%filter_eps_3c, qs_env, nl_3c, basis_set_RI, basis_set_AO, basis_set_AO, &
ri_data%ri_metric, int_eps=ri_data%eps_schwarz, op_pos=1, &
desymmetrize=ri_data%flavor == ri_pmat)
CALL dbcsr_t_destroy(t_3c_int_batched(1, 1))

CALL neighbor_list_3c_destroy(nl_3c)

CALL dbcsr_t_create(t_3c_int(1, 1), t_3c_tmp)

IF (ri_data%flavor == ri_pmat) THEN ! desymmetrize
! desymmetrize
CALL dbcsr_t_copy(t_3c_int(1, 1), t_3c_tmp)
CALL dbcsr_t_copy(t_3c_tmp, t_3c_int(1, 1), order=[1, 3, 2], summation=.TRUE., move_data=.TRUE.)
CALL dbcsr_t_filter(t_3c_int(1, 1), ri_data%filter_eps_3c)

ENDIF

CALL dbcsr_t_destroy(t_3c_tmp)

CALL build_2c_neighbor_lists(nl_2c_pot, basis_set_RI, basis_set_RI, ri_data%hfx_pot, &
"HFX_2c_nl_pot", &
qs_env, sym_ij=.TRUE., &
Expand Down Expand Up @@ -510,20 +564,23 @@ SUBROUTINE hfx_ri_pre_scf_Pmat(qs_env, ri_data)
CALL dbcsr_t_create(t_2c_int_mat(1), t_2c_int(1), name="(RI|RI)")
CALL dbcsr_t_copy_matrix_to_tensor(t_2c_int_mat(1), t_2c_int(1))
CALL dbcsr_release(t_2c_int_mat(1))
CALL dbcsr_t_copy(t_2c_int(1), ri_data%t_2c_int(1, 1))
CALL dbcsr_t_destroy(t_2c_int(1))
CALL dbcsr_t_filter(ri_data%t_2c_int(1, 1), ri_data%filter_eps)

CALL timestop(handle2)
CALL timeset(routineN//"_3c", handle2)

CALL dbcsr_t_create(ri_data%t_3c_int_ctr_3(1, 1), t_3c_int_2(1, 1), name="(RI | AO AO)")
CALL dbcsr_t_contract(dbcsr_scalar(1.0_dp), t_2c_int(1), ri_data%t_3c_int_ctr_3(1, 1), &
CALL dbcsr_t_contract(dbcsr_scalar(1.0_dp), ri_data%t_2c_int(1, 1), ri_data%t_3c_int_ctr_3(1, 1), &
dbcsr_scalar(0.0_dp), t_3c_int_2(1, 1), &
contract_1=[2], notcontract_1=[1], &
contract_2=[1], notcontract_2=[2, 3], &
map_1=[1], map_2=[2, 3], filter_eps=ri_data%filter_eps, &
unit_nr=unit_nr_dbcsr, move_data=.TRUE., &
pgrid_opt_2=pgrid_opt)
CALL dbcsr_t_clear(ri_data%t_3c_int_ctr_3(1, 1))
CALL dbcsr_t_destroy(t_2c_int(1))
CALL dbcsr_t_clear(ri_data%t_2c_int(1, 1))

CPASSERT(ASSOCIATED(pgrid_opt))
IF (ASSOCIATED(ri_data%pgrid_3)) THEN
Expand Down Expand Up @@ -760,7 +817,7 @@ SUBROUTINE hfx_ri_update_ks_mo(qs_env, ri_data, ks_matrix, mo_coeff, &
DO ispin = 1, nspins
CALL dbcsr_get_info(mo_coeff(ispin), nfullcols_total=n_mos)

CALL split_block_sizes([n_mos], mo_bsizes, ri_data%max_bsize_mo)
CALL split_block_sizes([n_mos], mo_bsizes, ri_data%min_bsize_mo)

CALL mp_environ(nproc, iproc, para_env%group)

Expand Down Expand Up @@ -1037,8 +1094,9 @@ SUBROUTINE hfx_ri_update_ks_Pmat(qs_env, ri_data, ks_matrix, rho_ao, &
routineP = moduleN//':'//routineN

INTEGER :: handle, handle2, ispin, unit_nr_dbcsr
INTEGER, ALLOCATABLE, DIMENSION(:) :: dist1, dist2
TYPE(dbcsr_t_pgrid_type), POINTER :: pgrid_opt
TYPE(dbcsr_t_type) :: ks_t, rho_ao_t
TYPE(dbcsr_t_type) :: ks_t, ks_tmp, rho_ao_t, rho_ao_tmp
TYPE(dbcsr_t_type), DIMENSION(1, 1) :: t_3c_1, t_3c_2

CALL timeset(routineN, handle)
Expand All @@ -1062,11 +1120,23 @@ SUBROUTINE hfx_ri_update_ks_Pmat(qs_env, ri_data, ks_matrix, rho_ao, &

CALL dbcsr_t_create(ri_data%t_3c_int_ctr_2(1, 1), t_3c_1(1, 1))

CALL dbcsr_t_create(ks_matrix(1, 1)%matrix, ks_t)
CALL dbcsr_t_create(rho_ao(1, 1)%matrix, rho_ao_t)
CALL dbcsr_t_create(ks_matrix(1, 1)%matrix, ks_tmp)
CALL dbcsr_t_create(rho_ao(1, 1)%matrix, rho_ao_tmp)

CALL create_2c_tensor(rho_ao_t, dist1, dist2, ri_data%pgrid_2d, &
ri_data%bsizes_AO_split, ri_data%bsizes_AO_split, &
name="(AO | AO)")
DEALLOCATE (dist1, dist2)

CALL create_2c_tensor(ks_t, dist1, dist2, ri_data%pgrid_2d, &
ri_data%bsizes_AO_split, ri_data%bsizes_AO_split, &
name="(AO | AO)")
DEALLOCATE (dist1, dist2)

DO ispin = 1, nspins
CALL dbcsr_t_copy_matrix_to_tensor(rho_ao(ispin, 1)%matrix, rho_ao_t)
CALL dbcsr_t_copy_matrix_to_tensor(rho_ao(ispin, 1)%matrix, rho_ao_tmp)
CALL dbcsr_t_copy(rho_ao_tmp, rho_ao_t)
CALL dbcsr_t_clear(rho_ao_tmp)

CALL timeset(routineN//"_Px3C", handle2)
CALL dbcsr_t_contract(dbcsr_scalar(1.0_dp), rho_ao_t, ri_data%t_3c_int_ctr_2(1, 1), &
Expand All @@ -1076,6 +1146,7 @@ SUBROUTINE hfx_ri_update_ks_Pmat(qs_env, ri_data, ks_matrix, rho_ao, &
map_1=[1], map_2=[2, 3], filter_eps=ri_data%filter_eps, &
unit_nr=unit_nr_dbcsr, &
pgrid_opt_2=pgrid_opt)
CALL dbcsr_t_clear(rho_ao_t)

CALL timestop(handle2)

Expand Down Expand Up @@ -1107,8 +1178,10 @@ SUBROUTINE hfx_ri_update_ks_Pmat(qs_env, ri_data, ks_matrix, rho_ao, &

CALL dbcsr_t_destroy(t_3c_2(1, 1))

CALL dbcsr_t_copy_tensor_to_matrix(ks_t, ks_matrix(ispin, 1)%matrix, summation=.TRUE.)
CALL dbcsr_t_copy(ks_t, ks_tmp)
CALL dbcsr_t_clear(ks_t)
CALL dbcsr_t_copy_tensor_to_matrix(ks_tmp, ks_matrix(ispin, 1)%matrix, summation=.TRUE.)
CALL dbcsr_t_clear(ks_tmp)

CPASSERT(ASSOCIATED(pgrid_opt))
IF (ASSOCIATED(ri_data%pgrid_1)) THEN
Expand All @@ -1121,7 +1194,9 @@ SUBROUTINE hfx_ri_update_ks_Pmat(qs_env, ri_data, ks_matrix, rho_ao, &

CALL dbcsr_t_destroy(t_3c_1(1, 1))
CALL dbcsr_t_destroy(rho_ao_t)
CALL dbcsr_t_destroy(rho_ao_tmp)
CALL dbcsr_t_destroy(ks_t)
CALL dbcsr_t_destroy(ks_tmp)

CALL timestop(handle)

Expand Down

0 comments on commit fc268b6

Please sign in to comment.