Skip to content

Commit

Permalink
Smarter buffer allocation for LIBXSMM 3-center contraction
Browse files Browse the repository at this point in the history
  • Loading branch information
abussy committed Apr 24, 2020
1 parent f2d83e1 commit 91d9f7f
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 31 deletions.
34 changes: 21 additions & 13 deletions src/qs_tensors.F
Original file line number Diff line number Diff line change
Expand Up @@ -861,9 +861,9 @@ SUBROUTINE build_3c_integrals(t3c, filter_eps, qs_env, &

INTEGER :: block_end_i, block_end_j, block_end_k, block_start_i, block_start_j, &
block_start_k, egfi, handle, handle2, i, iatom, ibasis, ikind, ilist, imax, iset, jatom, &
jcell, jkind, jset, katom, kcell, kkind, kset, m_max, max_nco, max_nset, max_nsgf, maxli, &
maxlj, maxlk, natom, nbasis, ncoi, ncoj, ncok, nimg, nseti, nsetj, nsetk, op_ij, op_jk, &
op_pos_prv, sgfi, sgfj, sgfk, unit_id
jcell, jkind, jset, katom, kcell, kkind, kset, m_max, max_ncoi, max_ncoj, max_ncok, &
max_nset, max_nsgfi, max_nsgfj, max_nsgfk, maxli, maxlj, maxlk, natom, nbasis, ncoi, &
ncoj, ncok, nimg, nseti, nsetj, nsetk, op_ij, op_jk, op_pos_prv, sgfi, sgfj, sgfk, unit_id
INTEGER, DIMENSION(3) :: blk_size, cell_j, cell_k, &
kp_index_lbounds, kp_index_ubounds, sp
INTEGER, DIMENSION(:), POINTER :: lmax_i, lmax_j, lmax_k, lmin_i, lmin_j, &
Expand Down Expand Up @@ -964,43 +964,51 @@ SUBROUTINE build_3c_integrals(t3c, filter_eps, qs_env, &

!Need the max l for each basis for libint and max nset, nco and nsgf for LIBXSMM contraction
nbasis = SIZE(basis_i)
max_nsgf = 0
max_nco = 0
max_nsgfi = 0
max_ncoi = 0
max_nset = 0
maxli = 0
DO ibasis = 1, nbasis
CALL get_gto_basis_set(gto_basis_set=basis_i(ibasis)%gto_basis_set, maxl=imax, &
nset=iset, nsgf_set=nsgfi, npgf=npgfi)
maxli = MAX(maxli, imax)
max_nset = MAX(max_nset, iset)
max_nsgf = MAX(max_nsgf, MAXVAL(nsgfi))
max_nco = MAX(max_nco, MAXVAL(npgfi)*ncoset(maxli))
max_nsgfi = MAX(max_nsgfi, MAXVAL(nsgfi))
max_ncoi = MAX(max_ncoi, MAXVAL(npgfi)*ncoset(maxli))
END DO
max_nsgfj = 0
max_ncoj = 0
maxlj = 0
DO ibasis = 1, nbasis
CALL get_gto_basis_set(gto_basis_set=basis_j(ibasis)%gto_basis_set, maxl=imax, &
nset=jset, nsgf_set=nsgfj, npgf=npgfj)
maxlj = MAX(maxlj, imax)
max_nset = MAX(max_nset, jset)
max_nsgf = MAX(max_nsgf, MAXVAL(nsgfj))
max_nco = MAX(max_nco, MAXVAL(npgfj)*ncoset(maxlj))
max_nsgfj = MAX(max_nsgfj, MAXVAL(nsgfj))
max_ncoj = MAX(max_ncoj, MAXVAL(npgfj)*ncoset(maxlj))
END DO
max_nsgfk = 0
max_ncok = 0
maxlk = 0
DO ibasis = 1, nbasis
CALL get_gto_basis_set(gto_basis_set=basis_k(ibasis)%gto_basis_set, maxl=imax, &
nset=kset, nsgf_set=nsgfk, npgf=npgfk)
maxlk = MAX(maxlk, imax)
max_nset = MAX(max_nset, kset)
max_nsgf = MAX(max_nsgf, MAXVAL(nsgfk))
max_nco = MAX(max_nco, MAXVAL(npgfk)*ncoset(maxlk))
max_nsgfk = MAX(max_nsgfk, MAXVAL(nsgfk))
max_ncok = MAX(max_ncok, MAXVAL(npgfk)*ncoset(maxlk))
END DO
m_max = maxli + maxlj + maxlk

!To minimize expensive memory opsand generally optimize contraction, pre-allocate buffers and
!contiguous sphi arrays (and transposed in the cas of sphi_i)
ALLOCATE (cpp_buffer(max_nsgf*max_nco), ccp_buffer(max_nsgf*max_nsgf*max_nco))
NULLIFY (tspi, tspj, spi, spj, spk)
IF (op_ij /= do_potential_id) THEN
ALLOCATE (cpp_buffer(max_nsgfj*max_ncok), ccp_buffer(max_nsgfj*max_nsgfk*max_ncoi))
ELSE
ALLOCATE (cpp_buffer(max_nsgfi*max_ncoj), ccp_buffer(max_nsgfi*max_nsgfj*max_ncok))
END IF

NULLIFY (tspi, tspj, spi, spj, spk)
IF (op_ij /= do_potential_id) THEN
ALLOCATE (spi(max_nset, nbasis), tspj(max_nset, nbasis), spk(max_nset, nbasis))
ELSE
Expand Down
37 changes: 19 additions & 18 deletions src/xas_tdp_integrals.F
Original file line number Diff line number Diff line change
Expand Up @@ -279,9 +279,9 @@ SUBROUTINE fill_pqX_tensor(pq_X, ab_nl, ac_nl, basis_set_list_a, basis_set_list_
routineP = moduleN//':'//routineN
INTEGER :: egfa, egfb, egfc, handle, i, iatom, ibasis, ikind, ilist, imax, iset, jatom, &
jkind, jset, katom, kkind, kset, m_max, max_nco, max_nset, max_nsgf, maxli, maxlj, maxlk, &
mepos, nbasis, ncoa, ncob, ncoc, ni, nj, nk, nseta, nsetb, nsetc, nthread, sgfa, sgfb, &
sgfc, unit_id
jkind, jset, katom, kkind, kset, m_max, max_ncob, max_ncoc, max_nset, max_nsgfa, &
max_nsgfb, maxli, maxlj, maxlk, mepos, nbasis, ncoa, ncob, ncoc, ni, nj, nk, nseta, &
nsetb, nsetc, nthread, sgfa, sgfb, sgfc, unit_id
INTEGER, DIMENSION(:), POINTER :: la_max, la_min, lb_max, lb_min, lc_max, &
lc_min, npgfa, npgfb, npgfc, nsgfa, &
nsgfb, nsgfc
Expand Down Expand Up @@ -315,35 +315,35 @@ SUBROUTINE fill_pqX_tensor(pq_X, ab_nl, ac_nl, basis_set_list_a, basis_set_list_
!Need the max l for each basis for libint (and overall max #of sets for screening)
nbasis = SIZE(basis_set_list_a)
max_nsgf = 0
max_nco = 0
max_nsgfa = 0
max_nset = 0
maxli = 0
DO ibasis = 1, nbasis
CALL get_gto_basis_set(gto_basis_set=basis_set_list_a(ibasis)%gto_basis_set, &
maxl=imax, nset=iset, nsgf_set=nsgfa, npgf=npgfa)
maxl=imax, nset=iset, nsgf_set=nsgfa)
maxli = MAX(maxli, imax)
max_nset = MAX(max_nset, iset)
max_nsgf = MAX(max_nsgf, MAXVAL(nsgfa))
max_nco = MAX(max_nco, MAXVAL(npgfa)*ncoset(maxli))
max_nsgfa = MAX(max_nsgfa, MAXVAL(nsgfa))
END DO
max_nsgfb = 0
max_ncob = 0
maxlj = 0
DO ibasis = 1, nbasis
CALL get_gto_basis_set(gto_basis_set=basis_set_list_b(ibasis)%gto_basis_set, &
maxl=imax, nset=iset, nsgf_set=nsgfb, npgf=npgfb)
maxlj = MAX(maxlj, imax)
max_nset = MAX(max_nset, iset)
max_nsgf = MAX(max_nsgf, MAXVAL(nsgfb))
max_nco = MAX(max_nco, MAXVAL(npgfb)*ncoset(maxlj))
max_nsgfb = MAX(max_nsgfb, MAXVAL(nsgfb))
max_ncob = MAX(max_ncob, MAXVAL(npgfb)*ncoset(maxlj))
END DO
maxlk = 0
max_ncoc = 0
DO ibasis = 1, nbasis
CALL get_gto_basis_set(gto_basis_set=basis_set_list_c(ibasis)%gto_basis_set, &
maxl=imax, nset=iset, nsgf_set=nsgfc, npgf=npgfc)
maxl=imax, nset=iset, npgf=npgfc)
maxlk = MAX(maxlk, imax)
max_nset = MAX(max_nset, iset)
max_nsgf = MAX(max_nsgf, MAXVAL(nsgfc))
max_nco = MAX(max_nco, MAXVAL(npgfc)*ncoset(maxlk))
max_ncoc = MAX(max_ncoc, MAXVAL(npgfc)*ncoset(maxlk))
END DO
m_max = maxli + maxlj + maxlk
Expand Down Expand Up @@ -474,9 +474,10 @@ SUBROUTINE fill_pqX_tensor(pq_X, ab_nl, ac_nl, basis_set_list_a, basis_set_list_
CALL o3c_iterator_create(o3c, o3c_iterator, nthread=nthread)
!$OMP PARALLEL DEFAULT(NONE) &
!$OMP SHARED (pq_X,do_screen,max_nset,basis_set_list_a,max_contra,max_contrb,max_contrc,max_nsgf,&
!$OMP basis_set_list_b, basis_set_list_c,ncoset,screen_radius,potential_parameter,max_nco,&
!$OMP my_eps_screen,maxli,maxlj,maxlk,my_sort_bc,nthread,o3c,o3c_iterator,tspa,spb,spc) &
!$OMP SHARED (pq_X,do_screen,max_nset,basis_set_list_a,max_contra,max_contrb,max_contrc,max_nsgfa,&
!$OMP basis_set_list_b, basis_set_list_c,ncoset,screen_radius,potential_parameter,max_ncob,&
!$OMP my_eps_screen,maxli,maxlj,maxlk,my_sort_bc,nthread,o3c,o3c_iterator,tspa,spb,spc,&
!$OMP max_ncoc,max_nsgfb) &
!$OMP PRIVATE (lib,i,mepos,work,iset,ncoa,sgfa,egfa,nseta,&
!$OMP iatom,ikind,jatom,jkind,katom,kkind,rij,rik,rjk,basis_set_a,nsetb,&
!$OMP la_max,la_min,lb_max,lb_min,lc_max,lc_min,npgfa,npgfb,npgfc,nsgfa,nsgfb,nsgfc,ri,rk,&
Expand All @@ -488,8 +489,8 @@ SUBROUTINE fill_pqX_tensor(pq_X, ab_nl, ac_nl, basis_set_list_a, basis_set_list_
!$ mepos = omp_get_thread_num()
!pre-allocate work buffers for LIBXSMM contract in order to avoid memory ops
ALLOCATE (cpp_buffer(max_nsgf*max_nco))
ALLOCATE (ccp_buffer(max_nsgf*max_nsgf*max_nco))
ALLOCATE (cpp_buffer(max_nsgfa*max_ncob))
ALLOCATE (ccp_buffer(max_nsgfa*max_nsgfb*max_ncoc))
!note: we do not initalize libxsmm here, because we assume that if the flag is there, then it
! is done in dbcsr already
Expand Down

0 comments on commit 91d9f7f

Please sign in to comment.