Skip to content

Commit

Permalink
Remove DBCSR backend for dense matrix multiplications
Browse files Browse the repository at this point in the history
  • Loading branch information
alazzaro committed Jul 21, 2021
1 parent eeb0f8b commit f8023de
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 204 deletions.
166 changes: 67 additions & 99 deletions src/cp_gemm_interface.F
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,18 @@
!> \author Fawzi Mohamed
! **************************************************************************************************
MODULE cp_gemm_interface
USE ISO_C_BINDING, ONLY: C_CHAR, &
C_DOUBLE, &
C_INT, &
C_LOC, &
C_PTR
USE cp_dbcsr_operations, ONLY: copy_dbcsr_to_fm_bc, &
copy_fm_to_dbcsr_bc
USE cp_fm_basic_linalg, ONLY: cp_fm_gemm
USE cp_fm_types, ONLY: cp_fm_get_info, &
cp_fm_get_mm_type, &
cp_fm_type
USE dbcsr_api, ONLY: dbcsr_multiply, &
dbcsr_release, &
dbcsr_type
USE input_constants, ONLY: do_cosma, &
do_dbcsr, &
do_scalapack
USE kinds, ONLY: dp
USE message_passing, ONLY: mp_min
USE offload_api, ONLY: offload_set_device
USE string_utilities, ONLY: uppercase
USE ISO_C_BINDING, ONLY: C_CHAR,&
C_DOUBLE,&
C_INT,&
C_LOC,&
C_PTR
USE cp_fm_basic_linalg, ONLY: cp_fm_gemm
USE cp_fm_types, ONLY: cp_fm_get_mm_type,&
cp_fm_type
USE input_constants, ONLY: do_cosma,&
do_scalapack
USE kinds, ONLY: dp
USE offload_api, ONLY: offload_set_device
#include "./base/base_uses.f90"

IMPLICIT NONE
Expand Down Expand Up @@ -77,73 +68,15 @@ SUBROUTINE cp_gemm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, &

CHARACTER(len=*), PARAMETER :: routineN = 'cp_gemm'

CHARACTER(LEN=1) :: my_trans
INTEGER :: handle, handle1, my_multi
INTEGER, PARAMETER :: &
#if defined(__COSMA)
my_multi_fallback = do_cosma
#else
my_multi_fallback = do_scalapack
#endif
INTEGER, DIMENSION(:), POINTER :: a_col_loc, a_row_loc, b_col_loc, &
b_row_loc, c_col_loc, c_row_loc
TYPE(dbcsr_type) :: a_db, b_db, c_db

CALL timeset(routineN, handle)

my_multi = cp_fm_get_mm_type()

! catch the special case that matrices have different blocking
! SCALAPACK/COSMA can deal with it but dbcsr doesn't like it
CALL cp_fm_get_info(matrix_a, nrow_locals=a_row_loc, ncol_locals=a_col_loc)
CALL cp_fm_get_info(matrix_b, nrow_locals=b_row_loc, ncol_locals=b_col_loc)
CALL cp_fm_get_info(matrix_c, nrow_locals=c_row_loc, ncol_locals=c_col_loc)
IF (PRESENT(a_first_row)) my_multi = my_multi_fallback
IF (PRESENT(a_first_col)) my_multi = my_multi_fallback
IF (PRESENT(b_first_row)) my_multi = my_multi_fallback
IF (PRESENT(b_first_col)) my_multi = my_multi_fallback
IF (PRESENT(c_first_row)) my_multi = my_multi_fallback
IF (PRESENT(c_first_col)) my_multi = my_multi_fallback
my_trans = transa; CALL uppercase(my_trans)
IF (my_trans == 'T') THEN
CALL cp_fm_get_info(matrix_a, nrow_locals=a_col_loc, ncol_locals=a_row_loc)
END IF
my_trans = transb; CALL uppercase(my_trans)
IF (my_trans == 'T') THEN
CALL cp_fm_get_info(matrix_b, nrow_locals=b_col_loc, ncol_locals=b_row_loc)
END IF
IF (my_multi .NE. do_scalapack .AND. my_multi .NE. do_cosma) THEN
IF (SIZE(a_row_loc) == SIZE(c_row_loc)) THEN
IF (ANY(a_row_loc - c_row_loc .NE. 0)) my_multi = my_multi_fallback
ELSE
my_multi = my_multi_fallback
END IF
END IF
IF (my_multi .NE. do_scalapack .AND. my_multi .NE. do_cosma) THEN
IF (SIZE(b_col_loc) == SIZE(c_col_loc)) THEN
IF (ANY(b_col_loc - c_col_loc .NE. 0)) my_multi = my_multi_fallback
ELSE
my_multi = my_multi_fallback
END IF
END IF
IF (my_multi .NE. do_scalapack .AND. my_multi .NE. do_cosma) THEN
IF (SIZE(a_col_loc) == SIZE(b_row_loc)) THEN
IF (ANY(a_col_loc - b_row_loc .NE. 0)) my_multi = my_multi_fallback
ELSE
my_multi = my_multi_fallback
END IF
END IF
! IMPORTANT do_scalapack is lowest value. If one processor has it set make all use it.
IF (cp_fm_get_mm_type() .NE. do_scalapack .AND. &
cp_fm_get_mm_type() .NE. do_cosma) CALL mp_min(my_multi, matrix_a%matrix_struct%para_env%group)
SELECT CASE (my_multi)
CASE (do_scalapack)
CALL timeset("cp_gemm_fm_gemm", handle1)
CALL timeset(routineN//"_fm_gemm", handle1)
CALL cp_fm_gemm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, matrix_c, &
a_first_col=a_first_col, &
a_first_row=a_first_row, &
Expand All @@ -154,27 +87,20 @@ SUBROUTINE cp_gemm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, &
CALL timestop(handle1)
CASE (do_cosma)
#if defined(__COSMA)
CALL timeset("cp_gemm_cosma", handle1)
CALL timeset(routineN//"_cosma", handle1)
CALL offload_set_device()
CALL cosma_pdgemm(transa=transa, transb=transb, m=m, n=n, k=k, alpha=alpha, &
matrix_a=matrix_a, matrix_b=matrix_b, beta=beta, matrix_c=matrix_c)
matrix_a=matrix_a, matrix_b=matrix_b, beta=beta, matrix_c=matrix_c, &
a_first_col=a_first_col, &
a_first_row=a_first_row, &
b_first_col=b_first_col, &
b_first_row=b_first_row, &
c_first_col=c_first_col, &
c_first_row=c_first_row)
CALL timestop(handle1)
#else
CPABORT("CP2K compiled without the COSMA library.")
#endif
CASE (do_dbcsr)
CALL timeset("cp_gemm_dbcsr_mm", handle1)
CALL copy_fm_to_dbcsr_bc(matrix_a, a_db)
CALL copy_fm_to_dbcsr_bc(matrix_b, b_db)
CALL copy_fm_to_dbcsr_bc(matrix_c, c_db)
CALL dbcsr_multiply(transa, transb, alpha, a_db, b_db, beta, c_db, last_k=k)
CALL copy_dbcsr_to_fm_bc(c_db, matrix_c)
CALL dbcsr_release(a_db)
CALL dbcsr_release(b_db)
CALL dbcsr_release(c_db)
CALL timestop(handle1)
END SELECT
CALL timestop(handle)

Expand All @@ -193,16 +119,27 @@ END SUBROUTINE cp_gemm
!> \param matrix_b ...
!> \param beta ...
!> \param matrix_c ...
!> \param a_first_col ...
!> \param a_first_row ...
!> \param b_first_col ...
!> \param b_first_row ...
!> \param c_first_col ...
!> \param c_first_row ...
!> \author Ole Schuett
! **************************************************************************************************
SUBROUTINE cosma_pdgemm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, matrix_c)
SUBROUTINE cosma_pdgemm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, matrix_c, &
a_first_col, a_first_row, b_first_col, b_first_row, &
c_first_col, c_first_row)
CHARACTER(LEN=1), INTENT(IN) :: transa, transb
INTEGER, INTENT(IN) :: m, n, k
REAL(KIND=dp), INTENT(IN) :: alpha
TYPE(cp_fm_type), POINTER :: matrix_a, matrix_b
REAL(KIND=dp), INTENT(IN) :: beta
TYPE(cp_fm_type), POINTER :: matrix_c
INTEGER, INTENT(IN), OPTIONAL :: a_first_col, a_first_row, b_first_col, &
b_first_row, c_first_col, c_first_row

INTEGER :: i_a, i_b, i_c, j_a, j_b, j_c
INTERFACE
SUBROUTINE cosma_pdgemm_c(transa, transb, m, n, k, alpha, a, ia, ja, desca, &
b, ib, jb, descb, beta, c, ic, jc, descc) &
Expand Down Expand Up @@ -230,14 +167,45 @@ SUBROUTINE cosma_pdgemm_c(transa, transb, m, n, k, alpha, a, ia, ja, desca, &
END SUBROUTINE cosma_pdgemm_c
END INTERFACE

IF (PRESENT(a_first_row)) THEN
i_a = a_first_row
ELSE
i_a = 1
END IF
IF (PRESENT(a_first_col)) THEN
j_a = a_first_col
ELSE
j_a = 1
END IF
IF (PRESENT(b_first_row)) THEN
i_b = b_first_row
ELSE
i_b = 1
END IF
IF (PRESENT(b_first_col)) THEN
j_b = b_first_col
ELSE
j_b = 1
END IF
IF (PRESENT(c_first_row)) THEN
i_c = c_first_row
ELSE
i_c = 1
END IF
IF (PRESENT(c_first_col)) THEN
j_c = c_first_col
ELSE
j_c = 1
END IF

CALL cosma_pdgemm_c(transa=transa, transb=transb, m=m, n=n, k=k, &
alpha=alpha, &
a=C_LOC(matrix_a%local_data(1, 1)), ia=1, ja=1, &
a=C_LOC(matrix_a%local_data(1, 1)), ia=i_a, ja=j_a, &
desca=C_LOC(matrix_a%matrix_struct%descriptor(1)), &
b=C_LOC(matrix_b%local_data(1, 1)), ib=1, jb=1, &
b=C_LOC(matrix_b%local_data(1, 1)), ib=i_b, jb=j_b, &
descb=C_LOC(matrix_b%matrix_struct%descriptor(1)), &
beta=beta, &
c=C_LOC(matrix_c%local_data(1, 1)), ic=1, jc=1, &
c=C_LOC(matrix_c%local_data(1, 1)), ic=i_c, jc=j_c, &
descc=C_LOC(matrix_c%matrix_struct%descriptor(1)))

END SUBROUTINE cosma_pdgemm
Expand Down
9 changes: 3 additions & 6 deletions src/environment.F
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ MODULE environment
USE header, ONLY: cp2k_footer,&
cp2k_header
USE input_constants, ONLY: &
callgraph_all, callgraph_none, do_cosma, do_cp2k, do_dbcsr, do_diag_elpa, &
do_diag_scalapack, do_eip, do_farming, do_fft_fftw3, do_fft_sg, do_fist, do_qs, &
do_scalapack, do_sirius, do_test, energy_run, id_development_version, mol_dyn_run, none_run
callgraph_all, callgraph_none, do_cosma, do_cp2k, do_diag_elpa, do_diag_scalapack, do_eip, &
do_farming, do_fft_fftw3, do_fft_sg, do_fist, do_qs, do_scalapack, do_sirius, do_test, &
energy_run, id_development_version, mol_dyn_run, none_run
USE input_cp2k_global, ONLY: create_global_section
USE input_enumeration_types, ONLY: enum_i2c,&
enumeration_type
Expand Down Expand Up @@ -794,9 +794,6 @@ SUBROUTINE read_global_section(root_section, para_env, globenv)
CASE (do_scalapack)
WRITE (UNIT=output_unit, FMT="(T2,A,T72,A)") &
start_section_label//"| Matrix multiplication library", "ScaLAPACK"
CASE (do_dbcsr)
WRITE (UNIT=output_unit, FMT="(T2,A,T76,A)") &
start_section_label//"| Matrix multiplication library", "DBCSR"
CASE (do_cosma)
WRITE (UNIT=output_unit, FMT="(T2,A,T76,A)") &
start_section_label//"| Matrix multiplication library", "COSMA"
Expand Down
3 changes: 1 addition & 2 deletions src/input_constants.F
Original file line number Diff line number Diff line change
Expand Up @@ -1086,8 +1086,7 @@ MODULE input_constants

! fm matrix multiplication
INTEGER, PARAMETER, PUBLIC :: do_scalapack = 1, &
do_cosma = 2, &
do_dbcsr = 3
do_cosma = 2

! Dispersion DFTB
INTEGER, PARAMETER, PUBLIC :: dispersion_uff = 100, &
Expand Down
18 changes: 8 additions & 10 deletions src/input_cp2k_global.F
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,12 @@ MODULE input_cp2k_global
GRID_BACKEND_REF
USE input_constants, ONLY: &
bsse_run, callgraph_all, callgraph_master, callgraph_none, cell_opt_run, debug_run, &
do_atom, do_band, do_cosma, do_cp2k, do_dbcsr, do_diag_elpa, do_diag_scalapack, &
do_farming, do_fft_fftw3, do_fft_sg, do_opt_basis, do_optimize_input, do_scalapack, &
do_swarm, do_tamc, do_test, do_tree_mc, do_tree_mc_ana, driver_run, ehrenfest, &
electronic_spectra_run, energy_force_run, energy_run, fftw_plan_estimate, &
fftw_plan_exhaustive, fftw_plan_measure, fftw_plan_patient, gaussian, geo_opt_run, &
linear_response_run, mol_dyn_run, mon_car_run, negf_run, none_run, pint_run, &
real_time_propagation, tree_mc_run, vib_anal
do_atom, do_band, do_cosma, do_cp2k, do_diag_elpa, do_diag_scalapack, do_farming, &
do_fft_fftw3, do_fft_sg, do_opt_basis, do_optimize_input, do_scalapack, do_swarm, do_tamc, &
do_test, do_tree_mc, do_tree_mc_ana, driver_run, ehrenfest, electronic_spectra_run, &
energy_force_run, energy_run, fftw_plan_estimate, fftw_plan_exhaustive, fftw_plan_measure, &
fftw_plan_patient, gaussian, geo_opt_run, linear_response_run, mol_dyn_run, mon_car_run, &
negf_run, none_run, pint_run, real_time_propagation, tree_mc_run, vib_anal
USE input_keyword_types, ONLY: keyword_create,&
keyword_release,&
keyword_type
Expand Down Expand Up @@ -639,11 +638,10 @@ SUBROUTINE create_fm_section(section)
"FORCE_BLOCK_SIZE should be set. The performance on GPU's depends "// &
"crucially on the BLOCK_SIZES. Make sure optimized kernels are available.", &
default_i_val=default_matmul, &
enum_i_vals=(/do_scalapack, do_scalapack, do_dbcsr, do_cosma/), &
enum_c_vals=s2a("SCALAPACK", "PDGEMM", "DBCSR_MM", "COSMA"), &
enum_i_vals=(/do_scalapack, do_scalapack, do_cosma/), &
enum_c_vals=s2a("SCALAPACK", "PDGEMM", "COSMA"), &
enum_desc=s2a("Standard ScaLAPACK pdgemm", &
"Alias for ScaLAPACK", &
"DBCSR_MM is employed. This needs local transformation of the matrices", &
"COSMA is employed. See https://github.com/eth-cscs/COSMA."))
CALL section_add_keyword(section, keyword)
CALL keyword_release(keyword)
Expand Down
86 changes: 0 additions & 86 deletions tests/QS/regtest-rtp-3/H2O_rtp_dbcsr_gemm.inp

This file was deleted.

1 change: 0 additions & 1 deletion tests/QS/regtest-rtp-3/TEST_FILES
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,5 @@ H2O-delta-01.inp 1 2e-12
H2O-delta-02.inp 1 1e-13 -17.17819891050733
H2O-delta-03.inp 1 1e-13 -16.81088185881686
H2O-delta-04.inp 1 4e-14 -17.17819854564585
H2O_rtp_dbcsr_gemm.inp 2 1.0E-14 -0.171661642587E+02
H2O_added_mos_emd.inp 1 3e-14 -17.16616425765393
#EOF

0 comments on commit f8023de

Please sign in to comment.