Skip to content

Commit

Permalink
Active space: Implement ERI calculation using half-transformed integr…
Browse files Browse the repository at this point in the history
…als (#3082)



This PR implements several options which have already been available in the input but were not implemented:

    half-transformed integral calculation algorithm: borrowed from the GPW-MP2 method. It is a much fast algorithm to calculate the integrals, especially if all orbitals are correlated. This implementation is 2-3 faster than the legacy FULL_GPW algorithm.
    store_wfn: this option enables the on-the-fly calculation the orbital functions on a grid. It increases the costs by 40 % in case of half-transformed integrals (no idea in case of the FULL_GPW method but probably much more expensive than the with storing the orbitals). It allows to reduce the memory footprint significantly, especially in case of molecular systems with their larger number of grid points
    subgroups: As in the WFC module, the calculation of integrals is possible in subgroups trading memory for reduced communication. I measured a threefold improvement.
    The new implementation is GPU accelerated as it relies on DBCSR and the grid code for the integration of the potential (untested).
    A bunch of new regression tests (QS/regtest-as-2, QS/regtest-as-3)

This implementation works well with the GPW method but is inadequate if highly localised AOs are present (known issue of GPW). In these cases, a GAPW-like implementation would be advantageous but a matter of implementation.
  • Loading branch information
fstein93 committed Oct 31, 2023
1 parent 4b4a740 commit b95f545
Show file tree
Hide file tree
Showing 19 changed files with 1,216 additions and 305 deletions.
23 changes: 21 additions & 2 deletions src/input_cp2k_dft.F
Original file line number Diff line number Diff line change
Expand Up @@ -9394,7 +9394,7 @@ SUBROUTINE create_eri_section(section)
CALL keyword_create( &
keyword, __LOCATION__, name="EPS_INTEGRAL", &
description="Accuracy of ERIs that will be stored.", &
usage="EPS_FILTER 1.0E-10 ", type_of_var=real_t, &
usage="EPS_INTEGRAL 1.0E-10 ", type_of_var=real_t, &
default_r_val=1.0E-12_dp)
CALL section_add_keyword(section, keyword)
CALL keyword_release(keyword)
Expand Down Expand Up @@ -9423,6 +9423,14 @@ SUBROUTINE create_eri_gpw(section)
CALL section_add_keyword(section, keyword)
CALL keyword_release(keyword)

CALL keyword_create(keyword, __LOCATION__, name="EPS_FILTER", &
description="Determines a threshold for the sparse matrix multiplications if METHOD "// &
"GPW_HALF_TRANSFORM is used", &
usage="EPS_FILTER 1.0E-9 ", type_of_var=real_t, &
default_r_val=1.0E-9_dp)
CALL section_add_keyword(section, keyword)
CALL keyword_release(keyword)

CALL keyword_create(keyword, __LOCATION__, name="CUTOFF", &
description="The cutoff of the finest grid level in the GPW integration.", &
usage="CUTOFF 300", type_of_var=real_t, &
Expand All @@ -9440,12 +9448,23 @@ SUBROUTINE create_eri_gpw(section)

CALL keyword_create(keyword, __LOCATION__, name="STORE_WFN", &
variants=(/"STORE_WAVEFUNCTION"/), &
description="Strore wavefunction in real space representation for integration.", &
description="Store wavefunction in real space representation for integration.", &
usage="STORE_WFN T", type_of_var=logical_t, &
default_l_val=.TRUE., lone_keyword_l_val=.TRUE.)
CALL section_add_keyword(section, keyword)
CALL keyword_release(keyword)

CALL keyword_create(keyword, __LOCATION__, name="GROUP_SIZE", &
description="Sets the size of a subgroup for ERI calculation, "// &
"each of which with a full set of work grids, arrays or orbitals "// &
"depending on the method of grids (work grids, arrays, orbitals). "// &
"Small numbers reduce communication but increase the memory demands. "// &
"A negative number indicates all processes (default).", &
usage="GROUP_SIZE 2", type_of_var=integer_t, &
default_i_val=-1)
CALL section_add_keyword(section, keyword)
CALL keyword_release(keyword)

CALL keyword_create(keyword, __LOCATION__, name="PRINT_LEVEL", &
variants=(/"IOLEVEL"/), &
description="How much output is written by the individual groups.", &
Expand Down
218 changes: 112 additions & 106 deletions src/mp2_gpw.F

Large diffs are not rendered by default.

158 changes: 64 additions & 94 deletions src/mp2_gpw_method.F

Large diffs are not rendered by default.

369 changes: 288 additions & 81 deletions src/qs_active_space_methods.F

Large diffs are not rendered by default.

75 changes: 55 additions & 20 deletions src/qs_active_space_types.F
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ MODULE qs_active_space_types
USE dbcsr_api, ONLY: dbcsr_csr_destroy,&
dbcsr_csr_p_type,&
dbcsr_p_type
USE input_constants, ONLY: eri_method_gpw_ht
USE kinds, ONLY: default_path_length,&
dp
USE message_passing, ONLY: mp_comm_type
Expand All @@ -41,11 +42,13 @@ MODULE qs_active_space_types
! **************************************************************************************************
TYPE eri_gpw_type
LOGICAL :: redo_poisson = .FALSE.
LOGICAL :: store_wfn = .FALSE.
REAL(KIND=dp) :: cutoff = 0.0_dp
REAL(KIND=dp) :: rel_cutoff = 0.0_dp
REAL(KIND=dp) :: eps_grid = 0.0_dp
REAL(KIND=dp) :: eps_filter = 0.0_dp
INTEGER :: print_level = 0
LOGICAL :: store_wfn = .FALSE.
INTEGER :: group_size = 0
END TYPE eri_gpw_type

TYPE eri_type
Expand Down Expand Up @@ -234,7 +237,7 @@ INTEGER FUNCTION csr_idx_to_combined(i, j, n) RESULT(ij)

ij = (i - 1)*n - ((i - 1)*(i - 2))/2 + (j - i + 1)

CPASSERT(ij <= (n*(n + 1))/2)
CPASSERT(ij <= (n*(n + 1))/2 .AND. 0 <= ij)

END FUNCTION csr_idx_to_combined

Expand Down Expand Up @@ -277,7 +280,8 @@ END SUBROUTINE csr_idx_from_combined
FUNCTION get_irange_csr(nindex, mp_group) RESULT(irange)
USE message_passing, ONLY: mp_comm_type
INTEGER, INTENT(IN) :: nindex
TYPE(mp_comm_type), INTENT(IN) :: mp_group

CLASS(mp_comm_type), INTENT(IN) :: mp_group
INTEGER, DIMENSION(2) :: irange

REAL(KIND=dp) :: rat
Expand Down Expand Up @@ -323,8 +327,8 @@ SUBROUTINE eri_type_eri_foreach(this, nspin, active_orbitals, fobj, spin1, spin2
INTEGER, DIMENSION(:, :), INTENT(IN) :: active_orbitals
INTEGER, OPTIONAL :: spin1, spin2
INTEGER :: i1, i12, i12l, i2, i3, i34, i34l, i4, m1, m2, m3, m4, &
irange(2), irptr, nspin, nindex, nmo
INTEGER, ALLOCATABLE, DIMENSION(:) :: colind
irange(2), irptr, nspin, nindex, nmo, proc, nonzero_elements_local
INTEGER, ALLOCATABLE, DIMENSION(:) :: colind, offsets, nonzero_elements_global
REAL(KIND=dp), ALLOCATABLE, DIMENSION(:) :: erival
REAL(KIND=dp) :: erint
TYPE(mp_comm_type) :: mp_group
Expand All @@ -339,34 +343,65 @@ SUBROUTINE eri_type_eri_foreach(this, nspin, active_orbitals, fobj, spin1, spin2
ASSOCIATE (eri => this%eri(nspin)%csr_mat, norb => this%norb)
nindex = (norb*(norb + 1))/2
CALL mp_group%set_handle(eri%mp_group%get_handle())
nmo = SIZE(active_orbitals, 1)
! Irrelevant in case of half-transformed integrals
irange = get_irange_csr(nindex, mp_group)
ALLOCATE (erival(nindex), colind(nindex))

nmo = SIZE(active_orbitals, 1)
IF (this%method == eri_method_gpw_ht) THEN
ALLOCATE (offsets(0:mp_group%num_pe - 1), &
nonzero_elements_global(0:mp_group%num_pe - 1))
END IF

DO m1 = 1, nmo
i1 = active_orbitals(m1, spin1)
DO m2 = m1, nmo
i2 = active_orbitals(m2, spin1)
i12 = csr_idx_to_combined(i1, i2, norb)

IF (i12 >= irange(1) .AND. i12 <= irange(2)) THEN
i12l = i12 - irange(1) + 1
irptr = eri%rowptr_local(i12l)
nindex = eri%nzerow_local(i12l)
colind(1:nindex) = eri%colind_local(irptr:irptr + nindex - 1)
erival(1:nindex) = eri%nzval_local%r_dp(irptr:irptr + nindex - 1)
IF (this%method == eri_method_gpw_ht) THEN
! In case of half-transformed integrals, every process might carry integrals of a row
! The number of integrals varies between processes and rows (related to the randomized
! distribution of matrix blocks)

! 1) Collect the amount of local data from each process
nonzero_elements_local = eri%nzerow_local(i12)
CALL mp_group%allgather(nonzero_elements_local, nonzero_elements_global)

! 2) Prepare arrays for communication (calculate the offsets and the total number of elements)
offsets(0) = 0
DO proc = 1, mp_group%num_pe - 1
offsets(proc) = offsets(proc - 1) + nonzero_elements_global(proc - 1)
END DO
nindex = offsets(mp_group%num_pe - 1) + nonzero_elements_global(mp_group%num_pe - 1)
irptr = eri%rowptr_local(i12)

! Exchange actual data
CALL mp_group%allgatherv(eri%colind_local(irptr:irptr + nonzero_elements_local - 1), &
colind(1:nindex), nonzero_elements_global, offsets)
CALL mp_group%allgatherv(eri%nzval_local%r_dp(irptr:irptr + nonzero_elements_local - 1), &
erival(1:nindex), nonzero_elements_global, offsets)
ELSE
erival = 0.0_dp
colind = 0
nindex = 0
! Here, the rows are distributed among the processes such that each process
! carries all integral of a set of rows
IF (i12 >= irange(1) .AND. i12 <= irange(2)) THEN
i12l = i12 - irange(1) + 1
irptr = eri%rowptr_local(i12l)
nindex = eri%nzerow_local(i12l)
colind(1:nindex) = eri%colind_local(irptr:irptr + nindex - 1)
erival(1:nindex) = eri%nzval_local%r_dp(irptr:irptr + nindex - 1)
ELSE
erival = 0.0_dp
colind = 0
nindex = 0
END IF

! Thus, a simple summation is sufficient
CALL mp_group%sum(nindex)
CALL mp_group%sum(colind(1:nindex))
CALL mp_group%sum(erival(1:nindex))
END IF

CALL mp_group%sum(nindex)
CALL mp_group%sum(colind(1:nindex))
CALL mp_group%sum(erival(1:nindex))
CALL mp_group%sync()

DO i34l = 1, nindex
i34 = colind(i34l)
erint = erival(i34l)
Expand Down
19 changes: 19 additions & 0 deletions tests/QS/regtest-as-2/TEST_FILES
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# runs are executed in the same order as in this file
# the second field tells which test should be run in order to compare with the last available output
# e.g. 0 means do not compare anything, running is enough
# 1 compares the last total energy in the file
# for details see cp2k/tools/do_regtest
#
h2_gpw_nostore.inp 1 1e-12 -1.12622646044780
h2_gpw_nostore.inp 92 1e-8 4.08481739
h2_gpw_ht.inp 1 1e-12 -1.12622646044780
h2_gpw_ht.inp 92 1e-8 4.08599253
h2_gpw_ht_nostore.inp 1 1e-12 -1.12622646044780
h2_gpw_ht_nostore.inp 92 1e-8 4.08599253
h2_gpw_nostore_group.inp 1 1e-12 -1.12622646044780
h2_gpw_nostore_group.inp 92 1e-8 4.08481739
h2_gpw_ht_group.inp 1 1e-12 -1.12622646044780
h2_gpw_ht_group.inp 92 1e-8 4.08599253
h2_gpw_ht_nostore_group.inp 1 1e-12 -1.12622646044780
h2_gpw_ht_nostore_group.inp 92 1e-8 4.08599253
#EOF
70 changes: 70 additions & 0 deletions tests/QS/regtest-as-2/h2_gpw_ht.inp
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
&FORCE_EVAL
METHOD Quickstep
&DFT
&QS
METHOD GPW
&END QS
&XC
&XC_FUNCTIONAL NONE
&END XC_FUNCTIONAL
&HF 1.0
&END HF
&END XC
&POISSON
POISSON_SOLVER ANALYTIC
PERIODIC NONE
&END POISSON
&MGRID
CUTOFF 500
&END MGRID
&SCF
ADDED_MOS 3
MAX_SCF 10
&PRINT
&RESTART OFF
&END RESTART
&END PRINT
&END SCF
&PRINT
&AO_MATRICES
CORE_HAMILTONIAN TRUE
KINETIC_ENERGY TRUE
POTENTIAL_ENERGY TRUE
&END AO_MATRICES
&END PRINT
&ACTIVE_SPACE
ACTIVE_ELECTRONS 2
ACTIVE_ORBITALS 2
ISOLATED_SYSTEM TRUE
&ERI
METHOD GPW_HALF_TRANSFORM
PERIODICITY 0 0 0
&END ERI
&ERI_GPW
CUTOFF 500
&END ERI_GPW
&FCIDUMP
FILENAME __STD_OUT__
&END FCIDUMP
&END ACTIVE_SPACE
&END DFT
&SUBSYS
&CELL
ABC 6.0 6.0 6.0
PERIODIC NONE
&END CELL
&COORD
H 0.000 0.000 0.356
H 0.000 0.000 -0.356
&END COORD
&KIND H
BASIS_SET 6-31G*
POTENTIAL GTH-HF
&END KIND
&END SUBSYS
&END FORCE_EVAL
&GLOBAL
RUN_TYPE ENERGY
PRINT_LEVEL LOW
PROJECT h2_gpw_ht
&END GLOBAL
71 changes: 71 additions & 0 deletions tests/QS/regtest-as-2/h2_gpw_ht_group.inp
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
&FORCE_EVAL
METHOD Quickstep
&DFT
&QS
METHOD GPW
&END QS
&XC
&XC_FUNCTIONAL NONE
&END XC_FUNCTIONAL
&HF 1.0
&END HF
&END XC
&POISSON
POISSON_SOLVER ANALYTIC
PERIODIC NONE
&END POISSON
&MGRID
CUTOFF 500
&END MGRID
&SCF
ADDED_MOS 3
MAX_SCF 10
&PRINT
&RESTART OFF
&END RESTART
&END PRINT
&END SCF
&PRINT
&AO_MATRICES
CORE_HAMILTONIAN TRUE
KINETIC_ENERGY TRUE
POTENTIAL_ENERGY TRUE
&END AO_MATRICES
&END PRINT
&ACTIVE_SPACE
ACTIVE_ELECTRONS 2
ACTIVE_ORBITALS 2
ISOLATED_SYSTEM TRUE
&ERI
METHOD GPW_HALF_TRANSFORM
PERIODICITY 0 0 0
&END ERI
&ERI_GPW
CUTOFF 500
GROUP_SIZE 1
&END ERI_GPW
&FCIDUMP
FILENAME __STD_OUT__
&END FCIDUMP
&END ACTIVE_SPACE
&END DFT
&SUBSYS
&CELL
ABC 6.0 6.0 6.0
PERIODIC NONE
&END CELL
&COORD
H 0.000 0.000 0.356
H 0.000 0.000 -0.356
&END COORD
&KIND H
BASIS_SET 6-31G*
POTENTIAL GTH-HF
&END KIND
&END SUBSYS
&END FORCE_EVAL
&GLOBAL
RUN_TYPE ENERGY
PRINT_LEVEL LOW
PROJECT h2_gpw_ht_group
&END GLOBAL

0 comments on commit b95f545

Please sign in to comment.