Skip to content

Commit

Permalink
RI HFX: impose sparsity of 3-centers integrals to contracted tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
pseewald committed Mar 7, 2020
1 parent 28fe9f5 commit 95c38f7
Show file tree
Hide file tree
Showing 2 changed files with 233 additions and 25 deletions.
234 changes: 219 additions & 15 deletions src/hfx_ri.F
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,11 @@ MODULE hfx_ri
dbcsr_scalar, dbcsr_scale, dbcsr_type, dbcsr_type_no_symmetry, dbcsr_type_symmetric
USE dbcsr_tensor_api, ONLY: &
dbcsr_t_batched_contract_finalize, dbcsr_t_batched_contract_init, dbcsr_t_clear, &
dbcsr_t_contract, dbcsr_t_copy, dbcsr_t_copy_matrix_to_tensor, &
dbcsr_t_contract, dbcsr_t_contract_index, dbcsr_t_copy, dbcsr_t_copy_matrix_to_tensor, &
dbcsr_t_copy_tensor_to_matrix, dbcsr_t_create, dbcsr_t_destroy, dbcsr_t_filter, &
dbcsr_t_mp_environ_pgrid, dbcsr_t_nd_mp_comm, dbcsr_t_pgrid_change_dims, &
dbcsr_t_pgrid_create, dbcsr_t_pgrid_destroy, dbcsr_t_pgrid_type, dbcsr_t_type
dbcsr_t_pgrid_create, dbcsr_t_pgrid_destroy, dbcsr_t_pgrid_type, dbcsr_t_reserve_blocks, &
dbcsr_t_reserved_block_indices, dbcsr_t_type
USE distribution_2d_types, ONLY: distribution_2d_type
USE hfx_types, ONLY: hfx_ri_type
USE input_constants, ONLY: hfx_ri_do_2c_diag,&
Expand All @@ -48,7 +49,8 @@ MODULE hfx_ri
matrix_sqrt_newton_schulz
USE kinds, ONLY: default_string_length,&
dp
USE message_passing, ONLY: mp_cart_create,&
USE message_passing, ONLY: mp_allgather,&
mp_cart_create,&
mp_environ
USE particle_methods, ONLY: get_particle_set
USE particle_types, ONLY: particle_type
Expand Down Expand Up @@ -83,6 +85,7 @@ MODULE hfx_ri
distribution_3d_type,&
neighbor_list_3c_type,&
split_block_sizes
USE util, ONLY: sort
#include "./base/base_uses.f90"

IMPLICIT NONE
Expand Down Expand Up @@ -488,10 +491,19 @@ SUBROUTINE hfx_ri_pre_scf_Pmat(qs_env, ri_data)
CHARACTER(LEN=*), PARAMETER :: routineN = 'hfx_ri_pre_scf_Pmat', &
routineP = moduleN//':'//routineN

INTEGER :: handle, handle2, unit_nr, unit_nr_dbcsr
INTEGER :: handle, handle2, iblk, is, nblk, &
nblks_total, nrows, row, unit_nr, &
unit_nr_dbcsr
INTEGER, ALLOCATABLE, DIMENSION(:) :: cols, cols_local, dist1, dist2, dist3, &
offsets_proc, rows, rows_local, &
sizes_proc
INTEGER, ALLOCATABLE, DIMENSION(:, :) :: blk_ind, blk_ind_2d
INTEGER, DIMENSION(3) :: pdims
REAL(KIND=dp) :: threshold
TYPE(cp_blacs_env_type), POINTER :: blacs_env
TYPE(cp_para_env_type), POINTER :: para_env
TYPE(dbcsr_t_pgrid_type) :: pgrid
TYPE(dbcsr_t_type) :: RI_AO_structure
TYPE(dbcsr_t_type), DIMENSION(1) :: t_2c_int
TYPE(dbcsr_t_type), DIMENSION(1, 1) :: t_3c_int_1
TYPE(dbcsr_type), DIMENSION(1) :: t_2c_int_mat, t_2c_op_pot, t_2c_op_RI, &
Expand Down Expand Up @@ -559,20 +571,162 @@ SUBROUTINE hfx_ri_pre_scf_Pmat(qs_env, ri_data)

CALL timestop(handle2)

! get sparsity of AO/RI-index
pdims = [0, 0, 1]
CALL dbcsr_t_pgrid_create(para_env%group, pdims, pgrid)
CALL create_3c_tensor(RI_AO_structure, dist1, dist2, dist3, pgrid, &
ri_data%bsizes_RI_split, ri_data%bsizes_AO_split, ri_data%bsizes_AO_split, &
map1=[1], map2=[2, 3], &
name="O (RI | AO AO)")

CALL dbcsr_t_copy(ri_data%t_3c_int_ctr_1(1, 1), RI_AO_structure, order=[2, 1, 3])

CALL dbcsr_t_reserved_block_indices(RI_AO_structure, blk_ind)
CALL dbcsr_t_destroy(RI_AO_structure)

nblk = SIZE(blk_ind, 1)
ALLOCATE (blk_ind_2d(nblk, 2))
blk_ind_2d(:, :) = blk_ind(:, 1:2)
DEALLOCATE (blk_ind)
CALL sort_unique_blkind_2d(blk_ind_2d)
nblk = SIZE(blk_ind_2d, 1)

! merge blocks on all processes
ALLOCATE (sizes_proc(para_env%num_pe))
ALLOCATE (offsets_proc(para_env%num_pe))

CALL mp_allgather(nblk, sizes_proc, para_env%group)
nblks_total = SUM(sizes_proc)

offsets_proc(1) = 0
DO is = 2, SIZE(sizes_proc)
offsets_proc(is) = offsets_proc(is - 1) + sizes_proc(is - 1)
ENDDO

ALLOCATE (rows_local(nblk))
ALLOCATE (cols_local(nblk))

rows_local(:) = blk_ind_2d(:, 1)
cols_local(:) = blk_ind_2d(:, 2)

ALLOCATE (rows(nblks_total), cols(nblks_total))

CALL mp_allgather(rows_local, rows, sizes_proc, offsets_proc, para_env%group)
CALL mp_allgather(cols_local, cols, sizes_proc, offsets_proc, para_env%group)
DEALLOCATE (sizes_proc, offsets_proc, rows_local, cols_local)

ALLOCATE (ri_data%nonzero_pairs(nblks_total, 2))
ri_data%nonzero_pairs(:, 1) = rows
ri_data%nonzero_pairs(:, 2) = cols

DEALLOCATE (rows, cols)

CALL sort_unique_blkind_2d(ri_data%nonzero_pairs)

CPASSERT(nblks_total == SIZE(ri_data%nonzero_pairs, 1))

nrows = SIZE(ri_data%bsizes_RI_split)

ALLOCATE (ri_data%nonzero_rows(nrows + 1))

ASSOCIATE (rows=>ri_data%nonzero_pairs(:, 1))
iblk = 1
DO row = 1, nrows
DO WHILE (rows(iblk) < row)
iblk = iblk + 1
IF (iblk > nblks_total) EXIT
ENDDO
ri_data%nonzero_rows(row) = iblk
ENDDO
END ASSOCIATE

ri_data%nonzero_rows(nrows + 1) = nblks_total + 1

CALL timestop(handle)
END SUBROUTINE

! **************************************************************************************************
!> \brief Calculate Fock (AKA Kohn-Sham) matrix and calculate Hartree-Fock exchange energy based on RI expansion.
!> \brief Sorts 2d indices w.r.t. rows and columns
!> \param blk_ind ...
! **************************************************************************************************
SUBROUTINE sort_unique_blkind_2d(blk_ind)
INTEGER, ALLOCATABLE, DIMENSION(:, :), &
INTENT(INOUT) :: blk_ind

INTEGER :: end_ind, iblk, iblk_all, irow, nblk, &
ncols, start_ind
INTEGER, ALLOCATABLE, DIMENSION(:) :: ind_1, ind_2, sort_1, sort_2
INTEGER, ALLOCATABLE, DIMENSION(:, :) :: blk_ind_tmp

nblk = SIZE(blk_ind, 1)

ALLOCATE (sort_1(nblk))
ALLOCATE (ind_1(nblk))

sort_1(:) = blk_ind(:, 1)
CALL sort(sort_1, nblk, ind_1)

blk_ind(:, :) = blk_ind(ind_1, :)

start_ind = 1

DO WHILE (start_ind <= nblk)
irow = blk_ind(start_ind, 1)
end_ind = start_ind

IF (end_ind + 1 <= nblk) THEN
DO WHILE (blk_ind(end_ind + 1, 1) == irow)
end_ind = end_ind + 1
IF (end_ind + 1 > nblk) EXIT
ENDDO
ENDIF

ncols = end_ind - start_ind + 1
ALLOCATE (sort_2(ncols))
ALLOCATE (ind_2(ncols))
sort_2(:) = blk_ind(start_ind:end_ind, 2)
CALL sort(sort_2, ncols, ind_2)
ind_2 = ind_2 + start_ind - 1

blk_ind(start_ind:end_ind, :) = blk_ind(ind_2, :)
start_ind = end_ind + 1

DEALLOCATE (sort_2, ind_2)
ENDDO

ALLOCATE (blk_ind_tmp(nblk, 2))
blk_ind_tmp = 0

iblk = 0
DO iblk_all = 1, nblk
IF (iblk >= 1) THEN
IF (ALL(blk_ind_tmp(iblk, :) == blk_ind(iblk_all, :))) THEN
CYCLE
ENDIF
ENDIF
iblk = iblk + 1
blk_ind_tmp(iblk, :) = blk_ind(iblk_all, :)
ENDDO
nblk = iblk

DEALLOCATE (blk_ind)
ALLOCATE (blk_ind(nblk, 2))

blk_ind(:, :) = blk_ind_tmp(:nblk, :)

END SUBROUTINE

! **************************************************************************************************
!> \brief ...
!> \param qs_env ...
!> \param ri_data RI parameters and intermediate tensor data
!> \param ks_matrix Fock matrix
!> \param ehfx exchange energy
!> \param mos MO coefficients
!> \param rho_ao Density matrix
!> \param geometry_did_change flag that indicates we have to recalc integrals
!> \param nspins Number of spins
!> \param hf_fraction Fraction of exact exchange
!> \param ri_data ...
!> \param ks_matrix ...
!> \param ehfx ...
!> \param mos ...
!> \param rho_ao ...
!> \param geometry_did_change ...
!> \param nspins ...
!> \param hf_fraction ...
! **************************************************************************************************
SUBROUTINE hfx_ri_update_ks(qs_env, ri_data, ks_matrix, ehfx, mos, rho_ao, &
geometry_did_change, nspins, hf_fraction)
Expand Down Expand Up @@ -1055,8 +1209,12 @@ SUBROUTINE hfx_ri_update_ks_Pmat(qs_env, ri_data, ks_matrix, rho_ao, &
CHARACTER(LEN=*), PARAMETER :: routineN = 'hfx_ri_update_ks_Pmat', &
routineP = moduleN//':'//routineN

INTEGER :: handle, handle2, ispin, unit_nr_dbcsr
INTEGER :: col, handle, handle2, iblk, iblk_filter, &
iblkrow, ispin, nblk, nblk_filter, &
row, row_end, row_start, unit_nr_dbcsr
INTEGER, ALLOCATABLE, DIMENSION(:) :: dist1, dist2
INTEGER, ALLOCATABLE, DIMENSION(:, :) :: ctr_ind, ctr_ind_tmp
LOGICAL :: found
TYPE(dbcsr_t_pgrid_type), POINTER :: pgrid_opt
TYPE(dbcsr_t_type) :: ks_t, ks_tmp, rho_ao_t, rho_ao_tmp, &
t_3c_1, t_3c_2
Expand Down Expand Up @@ -1119,20 +1277,66 @@ SUBROUTINE hfx_ri_update_ks_Pmat(qs_env, ri_data, ks_matrix, rho_ao, &
CALL dbcsr_t_clear(t_3c_1)
CALL timestop(handle2)

! impose sparsity of 3-center integrals:

CALL dbcsr_t_contract_index(dbcsr_scalar(1.0_dp), ri_data%t_2c_int(1, 1), ri_data%t_3c_int_ctr_3(1, 1), &
dbcsr_scalar(0.0_dp), t_3c_2, &
contract_1=[2], notcontract_1=[1], &
contract_2=[1], notcontract_2=[2, 3], &
map_1=[1], map_2=[2, 3], result_index=ctr_ind, filter_eps=ri_data%filter_eps)

nblk = SIZE(ctr_ind, 1)
ALLOCATE (ctr_ind_tmp(nblk, 3))
iblk_filter = 0
DO iblk = 1, nblk
row = ctr_ind(iblk, 1)
col = ctr_ind(iblk, 2)
found = .FALSE.

row_start = ri_data%nonzero_rows(row)
row_end = ri_data%nonzero_rows(row + 1)

IF (row_start /= row_end) THEN
DO iblkrow = row_start, row_end - 1
CPASSERT(ri_data%nonzero_pairs(iblkrow, 1) == row)
IF (ri_data%nonzero_pairs(iblkrow, 2) == col) found = .TRUE.
ENDDO
ENDIF

IF (found) THEN
iblk_filter = iblk_filter + 1
ctr_ind_tmp(iblk_filter, :) = ctr_ind(iblk, :)
ENDIF

ENDDO

nblk_filter = iblk_filter

DEALLOCATE (ctr_ind)
ALLOCATE (ctr_ind(nblk_filter, 3))
ctr_ind(:, :) = ctr_ind_tmp(1:nblk_filter, :)
DEALLOCATE (ctr_ind_tmp)

CALL dbcsr_t_reserve_blocks(t_3c_2, ctr_ind)
DEALLOCATE (ctr_ind)

CALL timeset(routineN//"_RIx3C", handle2)
CALL dbcsr_t_contract(dbcsr_scalar(1.0_dp), ri_data%t_2c_int(1, 1), ri_data%t_3c_int_ctr_3(1, 1), &
dbcsr_scalar(0.0_dp), t_3c_2, &
contract_1=[2], notcontract_1=[1], &
contract_2=[1], notcontract_2=[2, 3], &
map_1=[1], map_2=[2, 3], filter_eps=ri_data%filter_eps, &
unit_nr=unit_nr_dbcsr, &
retain_sparsity=.TRUE., &
pgrid_opt_2=pgrid_opt)
CALL dbcsr_t_clear(ri_data%t_3c_int_ctr_3(1, 1))

CALL dbcsr_t_filter(t_3c_2, ri_data%filter_eps)

CALL timestop(handle2)

CALL timeset(routineN//"_copy_2", handle2)
CALL dbcsr_t_copy(t_3c_2, t_3c_1, order=[2, 1, 3], move_data=.TRUE.)
CALL dbcsr_t_clear(t_3c_2)
CALL timestop(handle2)

CPASSERT(ASSOCIATED(pgrid_opt))
Expand Down
24 changes: 14 additions & 10 deletions src/hfx_types.F
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,10 @@ MODULE hfx_types

! memory reduction factor
INTEGER :: n_mem

! relevant non-zero RI-AO pairs
INTEGER, DIMENSION(:, :), ALLOCATABLE :: nonzero_pairs
INTEGER, DIMENSION(:), ALLOCATABLE :: nonzero_rows
END TYPE

! **************************************************************************************************
Expand Down Expand Up @@ -1283,23 +1287,17 @@ SUBROUTINE hfx_ri_init(ri_data, qs_kind_set, particle_set, atomic_kind_set, para
CALL split_block_sizes(ri_data%bsizes_RI, ri_data%bsizes_RI_fit, libcusmm_maxblocksize)

IF (ri_data%flavor == ri_pmat) THEN
ALLOCATE (ri_data%t_2c_int(1, 1))
ALLOCATE (ri_data%t_3c_int_ctr_1(1, 1))

CALL create_3c_tensor(ri_data%t_3c_int_ctr_1(1, 1), ri_data%dist1_ao_1, ri_data%dist1_ri, ri_data%dist1_ao_2, &
ri_data%pgrid, ri_data%bsizes_AO_split, ri_data%bsizes_RI_split, ri_data%bsizes_AO_split, &
[1, 2], [3], name="(AO RI | AO)")

ALLOCATE (ri_data%t_3c_int_ctr_2(1, 1))
CALL create_3c_tensor(ri_data%t_3c_int_ctr_2(1, 1), ri_data%dist2_ao_1, ri_data%dist2_ri, ri_data%dist2_ao_2, &
ri_data%pgrid, ri_data%bsizes_AO_split, ri_data%bsizes_RI_split, ri_data%bsizes_AO_split, &
[1], [2, 3], name="(AO | RI AO)")

ALLOCATE (ri_data%t_3c_int_ctr_3(1, 1))
CALL create_3c_tensor(ri_data%t_3c_int_ctr_3(1, 1), ri_data%dist3_RI, ri_data%dist3_AO_1, ri_data%dist3_AO_2, &
ri_data%pgrid, ri_data%bsizes_RI_split, ri_data%bsizes_AO_split, ri_data%bsizes_AO_split, &
[1], [2, 3], name="(RI | AO AO)")

ALLOCATE (ri_data%t_2c_int(1, 1))
CALL create_2c_tensor(ri_data%t_2c_int(1, 1), dist1, dist2, ri_data%pgrid_2d, &
ri_data%bsizes_RI_split, ri_data%bsizes_RI_split, &
name="(RI | RI)")
Expand Down Expand Up @@ -1388,18 +1386,24 @@ SUBROUTINE hfx_ri_release(ri_data)
CALL distribution_3d_destroy(ri_data%dist_3d)
CALL dbcsr_t_distribution_destroy(ri_data%dist)

DEALLOCATE (ri_data%bsizes_RI)
DEALLOCATE (ri_data%bsizes_AO)
DEALLOCATE (ri_data%bsizes_AO_split)
DEALLOCATE (ri_data%bsizes_RI_split)
DEALLOCATE (ri_data%bsizes_AO_fit)
DEALLOCATE (ri_data%bsizes_RI_fit)

IF (ri_data%flavor == ri_pmat) THEN
CALL dbcsr_t_destroy(ri_data%t_3c_int_ctr_1(1, 1))
DEALLOCATE (ri_data%t_3c_int_ctr_1)
DEALLOCATE (ri_data%dist1_RI, ri_data%dist1_AO_1, ri_data%dist1_AO_2)
DEALLOCATE (ri_data%dist2_RI, ri_data%dist2_AO_1, ri_data%dist2_AO_2)
DEALLOCATE (ri_data%dist3_RI, ri_data%dist3_AO_1, ri_data%dist3_AO_2)
CALL dbcsr_t_destroy(ri_data%t_3c_int_ctr_2(1, 1))
DEALLOCATE (ri_data%t_3c_int_ctr_2)
CALL dbcsr_t_destroy(ri_data%t_3c_int_ctr_3(1, 1))
DEALLOCATE (ri_data%t_3c_int_ctr_3)
CALL dbcsr_t_destroy(ri_data%t_2c_int(1, 1))
DEALLOCATE (ri_data%t_2c_int)
DEALLOCATE (ri_data%nonzero_pairs)
DEALLOCATE (ri_data%nonzero_rows)
ELSEIF (ri_data%flavor == ri_mo) THEN
CALL dbcsr_t_destroy(ri_data%t_3c_int_ctr_1(1, 1))
CALL dbcsr_t_destroy(ri_data%t_3c_int_ctr_2(1, 1))
Expand Down

0 comments on commit 95c38f7

Please sign in to comment.