Skip to content

Commit

Permalink
RI HFX: better load balancing for RI-MO variant
Browse files Browse the repository at this point in the history
- store upper and lower triangular part of integral tensor independently
- adapt load balancing optimizations from RPA
  • Loading branch information
pseewald committed Apr 29, 2020
1 parent 2b878fb commit c9a73d8
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 28 deletions.
106 changes: 80 additions & 26 deletions src/hfx_ri.F
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,9 @@ SUBROUTINE hfx_ri_pre_scf_mo(qs_env, ri_data, nspins)

CALL timeset(routineN//"_3c", handle2)
CALL dbcsr_t_copy(ri_data%t_3c_int(1, 1), ri_data%t_3c_int_ctr_1(1, 1), order=[2, 1, 3], move_data=.TRUE.)
CALL dbcsr_t_filter(ri_data%t_3c_int_ctr_1(1, 1), ri_data%filter_eps)
CALL dbcsr_t_destroy(ri_data%t_3c_int(1, 1))
CALL dbcsr_t_filter(ri_data%t_3c_int_ctr_1(1, 1), ri_data%filter_eps)
CALL dbcsr_t_copy(ri_data%t_3c_int_ctr_1(1, 1), ri_data%t_3c_int_ctr_2(1, 1))
DEALLOCATE (ri_data%t_3c_int)
CALL timestop(handle2)

Expand Down Expand Up @@ -956,13 +957,13 @@ SUBROUTINE hfx_ri_update_ks_mo(qs_env, ri_data, ks_matrix, mo_coeff, &
routineP = moduleN//':'//routineN
REAL(dp), PARAMETER :: lb_ratio = 0.1_dp

INTEGER :: bsize, comm_2d, count, handle, handle2, i_mem, iproc, ispin, max_pdim_mo, n_mem, &
n_mos, nproc, nproc_rem, pdim_AO, pdim_AO_, pdim_AO__, pdim_mo, tdim_mo, unit_nr_dbcsr
INTEGER :: bsize, bsum, comm_2d, count, handle, handle2, i_mem, iblock, iproc, ispin, &
max_pdim_mo, n_mem, n_mos, nblock, nproc, nproc_rem, pdim_AO, pdim_AO_, pdim_AO__, &
pdim_mo, tdim_mo, unit_nr_dbcsr
INTEGER(int_8) :: nblks, nflop
INTEGER, ALLOCATABLE, DIMENSION(:) :: dist1, dist2, dist3, mem_end, &
mem_end_block, mem_size, mem_start, &
mem_start_block, mo_bsizes_1, &
mo_bsizes_2
INTEGER, ALLOCATABLE, DIMENSION(:) :: dist1, dist2, dist3, mem_end, mem_end_block_1, &
mem_end_block_2, mem_size, mem_start, mem_start_block_1, mem_start_block_2, mo_bsizes_1, &
mo_bsizes_2
INTEGER, ALLOCATABLE, DIMENSION(:, :) :: bounds
INTEGER, DIMENSION(2) :: pdims_2d
INTEGER, DIMENSION(3) :: pcoord, pdims, pdims_AO, pdims_RI
Expand All @@ -971,7 +972,8 @@ SUBROUTINE hfx_ri_update_ks_mo(qs_env, ri_data, ks_matrix, mo_coeff, &
TYPE(cp_para_env_type), POINTER :: para_env
TYPE(dbcsr_distribution_type) :: ks_dist
TYPE(dbcsr_t_pgrid_type) :: pgrid, pgrid_2d
TYPE(dbcsr_t_pgrid_type), POINTER :: pgrid_opt_KS, pgrid_opt_RI
TYPE(dbcsr_t_pgrid_type), POINTER :: pgrid_opt_KS, pgrid_opt_mo_L, &
pgrid_opt_mo_R, pgrid_opt_RI
TYPE(dbcsr_t_type) :: ks_t, ks_t_mat, mo_coeff_t, &
mo_coeff_t_split
TYPE(dbcsr_t_type), DIMENSION(1, 1) :: t_3c_int_mo_1, t_3c_int_mo_2
Expand All @@ -980,7 +982,7 @@ SUBROUTINE hfx_ri_update_ks_mo(qs_env, ri_data, ks_matrix, mo_coeff, &

CPASSERT(SIZE(ks_matrix, 2) == 1)

NULLIFY (pgrid_opt_RI, pgrid_opt_KS)
NULLIFY (pgrid_opt_RI, pgrid_opt_KS, pgrid_opt_mo_L, pgrid_opt_mo_R)

unit_nr_dbcsr = ri_data%unit_nr_dbcsr

Expand Down Expand Up @@ -1040,26 +1042,51 @@ SUBROUTINE hfx_ri_update_ks_mo(qs_env, ri_data, ks_matrix, mo_coeff, &
mo_bsizes_2 = 1

CALL create_tensor_batches(mo_bsizes_2, n_mem, mem_start, mem_end, &
mem_start_block, mem_end_block)
mem_start_block_2, mem_end_block_2)

DO i_mem = 1, n_mem
bsize = SUM(mo_bsizes_2(mem_start_block(i_mem):mem_end_block(i_mem)))
bsize = SUM(mo_bsizes_2(mem_start_block_2(i_mem):mem_end_block_2(i_mem)))
mem_size(i_mem) = bsize
ENDDO

CALL split_block_sizes(mem_size, mo_bsizes_1, ri_data%min_bsize_MO)
ALLOCATE (mem_start_block_1(n_mem))
ALLOCATE (mem_end_block_1(n_mem))
nblock = SIZE(mo_bsizes_1)
iblock = 0
DO i_mem = 1, n_mem
bsum = 0
DO
iblock = iblock + 1
CPASSERT(iblock <= nblock)
bsum = bsum + mo_bsizes_1(iblock)
IF (bsum == mem_size(i_mem)) THEN
IF (i_mem == 1) THEN
mem_start_block_1(i_mem) = 1
ELSE
mem_start_block_1(i_mem) = mem_end_block_1(i_mem - 1) + 1
ENDIF
mem_end_block_1(i_mem) = iblock
EXIT
ENDIF
ENDDO
ENDDO

CALL mp_environ(nproc, iproc, para_env%group)

CALL create_3c_tensor(t_3c_int_mo_1(1, 1), dist1, dist2, dist3, ri_data%pgrid_1, &
ri_data%bsizes_AO_split, ri_data%bsizes_RI_split, mo_bsizes_1, &
[1, 2], [3], name="(AO RI | MO)")
[1, 2], [3], &
starts_array_block_3=mem_start_block_1, ends_array_block_3=mem_end_block_1, &
name="(AO RI | MO)")

DEALLOCATE (dist1, dist2, dist3)

CALL create_3c_tensor(t_3c_int_mo_2(1, 1), dist1, dist2, dist3, ri_data%pgrid_2, &
mo_bsizes_1, ri_data%bsizes_RI_split, ri_data%bsizes_AO_split, &
[1], [2, 3], name="(MO | RI AO)")
[1], [2, 3], &
starts_array_block_1=mem_start_block_1, ends_array_block_1=mem_end_block_1, &
name="(MO | RI AO)")

DEALLOCATE (dist1, dist2, dist3)

Expand All @@ -1080,15 +1107,15 @@ SUBROUTINE hfx_ri_update_ks_mo(qs_env, ri_data, ks_matrix, mo_coeff, &
CALL create_3c_tensor(ri_data%t_3c_int_mo(ispin, 1, 1), dist1, dist2, dist3, pgrid, &
ri_data%bsizes_RI_fit, mo_bsizes_2, ri_data%bsizes_AO_fit, &
[1], [2, 3], &
starts_array_block_2=mem_start_block, ends_array_block_2=mem_end_block, &
starts_array_block_2=mem_start_block_2, ends_array_block_2=mem_end_block_2, &
name="(RI | MO AO)")

DEALLOCATE (dist1, dist2, dist3)

CALL create_3c_tensor(ri_data%t_3c_ctr_KS(ispin, 1, 1), dist1, dist2, dist3, pgrid, &
ri_data%bsizes_RI_fit, mo_bsizes_2, ri_data%bsizes_AO_fit, &
[1, 2], [3], &
starts_array_block_2=mem_start_block, ends_array_block_2=mem_end_block, &
starts_array_block_2=mem_start_block_2, ends_array_block_2=mem_end_block_2, &
name="(RI MO | AO)")
DEALLOCATE (dist1, dist2, dist3)
CALL dbcsr_t_pgrid_destroy(pgrid)
Expand Down Expand Up @@ -1137,10 +1164,23 @@ SUBROUTINE hfx_ri_update_ks_mo(qs_env, ri_data, ks_matrix, mo_coeff, &
pdims = [pdims_RI(1), nproc_rem/pdim_AO, pdim_AO]

CALL dbcsr_t_pgrid_change_dims(pgrid_opt_RI, pdims)
CALL tensor_change_pgrid(ri_data%t_3c_int_mo(ispin, 1, 1), pgrid_opt_RI, nodata=.TRUE., unit_nr=unit_nr_dbcsr)
CALL tensor_change_pgrid(ri_data%t_3c_ctr_RI(ispin, 1, 1), pgrid_opt_RI, nodata=.TRUE.)
CALL tensor_change_pgrid(ri_data%t_3c_ctr_KS(ispin, 1, 1), pgrid_opt_RI, nodata=.TRUE.)
CALL tensor_change_pgrid(ri_data%t_3c_ctr_KS_copy(ispin, 1, 1), pgrid_opt_RI, nodata=.TRUE.)
CALL tensor_change_pgrid(ri_data%t_3c_int_mo(ispin, 1, 1), pgrid_opt_RI, nodata=.TRUE., &
starts_array_mc_block_2=mem_start_block_2, ends_array_mc_block_2=mem_end_block_2, &
unit_nr=unit_nr_dbcsr)
CALL tensor_change_pgrid(ri_data%t_3c_ctr_RI(ispin, 1, 1), pgrid_opt_RI, nodata=.TRUE., &
starts_array_mc_block_2=mem_start_block_2, ends_array_mc_block_2=mem_end_block_2)
CALL tensor_change_pgrid(ri_data%t_3c_ctr_KS(ispin, 1, 1), pgrid_opt_RI, nodata=.TRUE., &
starts_array_mc_block_2=mem_start_block_2, ends_array_mc_block_2=mem_end_block_2)
CALL tensor_change_pgrid(ri_data%t_3c_ctr_KS_copy(ispin, 1, 1), pgrid_opt_RI, nodata=.TRUE., &
starts_array_mc_block_2=mem_start_block_2, ends_array_mc_block_2=mem_end_block_2)
CALL tensor_change_pgrid(t_3c_int_mo_1(1, 1), pgrid_opt_mo_R, nodata=.TRUE., &
starts_array_mc_block_3=mem_start_block_1, ends_array_mc_block_3=mem_end_block_1, &
unit_nr=unit_nr_dbcsr)
CALL tensor_change_pgrid(ri_data%t_3c_int_ctr_1(1, 1), pgrid_opt_mo_R)
CALL tensor_change_pgrid(t_3c_int_mo_2(1, 1), pgrid_opt_mo_L, nodata=.TRUE., &
starts_array_mc_block_1=mem_start_block_1, ends_array_mc_block_1=mem_end_block_1, &
unit_nr=unit_nr_dbcsr)
CALL tensor_change_pgrid(ri_data%t_3c_int_ctr_2(1, 1), pgrid_opt_mo_L)

! initialize batched contraction here since process grid have changed
CALL dbcsr_t_batched_contract_init(ri_data%t_2c_int(ispin, 1))
Expand All @@ -1156,13 +1196,19 @@ SUBROUTINE hfx_ri_update_ks_mo(qs_env, ri_data, ks_matrix, mo_coeff, &

bounds(:, 1) = [mem_start(i_mem), mem_end(i_mem)]

IF (ASSOCIATED(pgrid_opt_mo_R)) THEN
CALL dbcsr_t_pgrid_destroy(pgrid_opt_mo_R)
DEALLOCATE (pgrid_opt_mo_R)
ENDIF

CALL timeset(routineN//"_MOx3C_R", handle2)
CALL dbcsr_t_contract(dbcsr_scalar(1.0_dp), mo_coeff_t_split, ri_data%t_3c_int_ctr_1(1, 1), &
dbcsr_scalar(0.0_dp), t_3c_int_mo_1(1, 1), &
contract_1=[1], notcontract_1=[2], &
contract_2=[3], notcontract_2=[1, 2], &
map_1=[3], map_2=[1, 2], &
bounds_2=bounds, &
pgrid_opt_3=pgrid_opt_mo_R, &
filter_eps=ri_data%filter_eps_mo/2, &
unit_nr=unit_nr_dbcsr, &
move_data=.FALSE., &
Expand All @@ -1175,9 +1221,10 @@ SUBROUTINE hfx_ri_update_ks_mo(qs_env, ri_data, ks_matrix, mo_coeff, &
CALL dbcsr_t_copy(t_3c_int_mo_1(1, 1), ri_data%t_3c_int_mo(ispin, 1, 1), order=[3, 1, 2], move_data=.TRUE.)
CALL timestop(handle2)

CALL timeset(routineN//"_copy_symm", handle2)
CALL dbcsr_t_copy(ri_data%t_3c_int_ctr_1(1, 1), ri_data%t_3c_int_ctr_2(1, 1), move_data=.TRUE.)
CALL timestop(handle2)
IF (ASSOCIATED(pgrid_opt_mo_L)) THEN
CALL dbcsr_t_pgrid_destroy(pgrid_opt_mo_L)
DEALLOCATE (pgrid_opt_mo_L)
ENDIF

CALL timeset(routineN//"_MOx3C_L", handle2)
CALL dbcsr_t_contract(dbcsr_scalar(1.0_dp), mo_coeff_t_split, ri_data%t_3c_int_ctr_2(1, 1), &
Expand All @@ -1186,6 +1233,7 @@ SUBROUTINE hfx_ri_update_ks_mo(qs_env, ri_data, ks_matrix, mo_coeff, &
contract_2=[1], notcontract_2=[2, 3], &
map_1=[1], map_2=[2, 3], &
bounds_2=bounds, &
pgrid_opt_3=pgrid_opt_mo_L, &
filter_eps=ri_data%filter_eps_mo/2, &
unit_nr=unit_nr_dbcsr, &
move_data=.FALSE., &
Expand All @@ -1201,10 +1249,6 @@ SUBROUTINE hfx_ri_update_ks_mo(qs_env, ri_data, ks_matrix, mo_coeff, &
CALL dbcsr_t_filter(ri_data%t_3c_int_mo(ispin, 1, 1), ri_data%filter_eps_mo)
CALL timestop(handle2)

CALL timeset(routineN//"_copy_symm", handle2)
CALL dbcsr_t_copy(ri_data%t_3c_int_ctr_2(1, 1), ri_data%t_3c_int_ctr_1(1, 1), move_data=.TRUE.)
CALL timestop(handle2)

IF (ASSOCIATED(pgrid_opt_RI)) THEN
CALL dbcsr_t_pgrid_destroy(pgrid_opt_RI)
DEALLOCATE (pgrid_opt_RI)
Expand Down Expand Up @@ -1252,6 +1296,16 @@ SUBROUTINE hfx_ri_update_ks_mo(qs_env, ri_data, ks_matrix, mo_coeff, &
CALL timestop(handle2)
ENDDO

IF (ASSOCIATED(pgrid_opt_mo_R)) THEN
CALL dbcsr_t_pgrid_destroy(pgrid_opt_mo_R)
DEALLOCATE (pgrid_opt_mo_R)
ENDIF

IF (ASSOCIATED(pgrid_opt_mo_L)) THEN
CALL dbcsr_t_pgrid_destroy(pgrid_opt_mo_L)
DEALLOCATE (pgrid_opt_mo_L)
ENDIF

IF (ASSOCIATED(pgrid_opt_RI)) THEN
CALL dbcsr_t_pgrid_destroy(pgrid_opt_RI)
DEALLOCATE (pgrid_opt_RI)
Expand Down
7 changes: 5 additions & 2 deletions src/hfx_types.F
Original file line number Diff line number Diff line change
Expand Up @@ -1339,11 +1339,14 @@ SUBROUTINE hfx_ri_init(ri_data, qs_kind_set, particle_set, atomic_kind_set, para
! is larger than this (it is however not a problem for load balancing if actual MO dimension
! is slightly smaller)
MO_dim = MAX((ri_data%nelectron_total/2 - 1)/ri_data%n_mem + 1, 1)
MO_dim = (MO_dim - 1)/ri_data%min_bsize_MO + 1

pdims = [1, nproc, 1]
CALL dbcsr_t_mp_dims_create(nproc, pdims, [MO_dim, SIZE(ri_data%bsizes_RI_split), MO_dim])
pdims = 0
CALL dbcsr_t_mp_dims_create(nproc, pdims, [SIZE(ri_data%bsizes_AO_split), SIZE(ri_data%bsizes_RI_split), MO_dim])

CALL dbcsr_t_pgrid_create(para_env%group, pdims, ri_data%pgrid_1)

pdims = pdims([3, 2, 1])
CALL dbcsr_t_pgrid_create(para_env%group, pdims, ri_data%pgrid_2)

CALL create_3c_tensor(ri_data%t_3c_int_ctr_1(1, 1), ri_data%dist1_ao_1, ri_data%dist1_ri, ri_data%dist1_ao_2, &
Expand Down

0 comments on commit c9a73d8

Please sign in to comment.