Skip to content

Commit

Permalink
COSMA: Switch to prefixed_pxgemm API and upgrade to v2.5.0
Browse files Browse the repository at this point in the history
  • Loading branch information
oschuett committed May 27, 2021
1 parent 907ab74 commit 8ad318d
Show file tree
Hide file tree
Showing 8 changed files with 179 additions and 75 deletions.
7 changes: 3 additions & 4 deletions INSTALL.md
Original file line number Diff line number Diff line change
Expand Up @@ -301,10 +301,9 @@ SIRIUS is a domain specific library for electronic structure calculations.

### 2s. COSMA (Distributed Communication-Optimal Matrix-Matrix Multiplication Algorithm)

- COSMA is a replacement of the pdgemm routine included in scalapack. The
library supports both CPU and GPUs. No specific flag during compilation is
needed to use the library in cp2k, excepted during linking time where the
library should be placed in front of the scalapack library.
- COSMA is an alternative for the pdgemm routine included in ScaLAPACK.
The library supports both CPU and GPUs.
- Add `-D__COSMA` to the DFLAGS to enable support for COSMA.
- see <https://github.com/eth-cscs/COSMA> for more information.

### 2t. LibVori (Voronoi Integration for Electrostatic Properties from Electron Density)
Expand Down
3 changes: 3 additions & 0 deletions src/cp2k_info.F
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ FUNCTION cp2k_flags() RESULT(flags)
#if defined(__SCALAPACK)
flags = TRIM(flags)//" scalapack"
#endif
#if defined(__COSMA)
flags = TRIM(flags)//" cosma"
#endif

#if defined(__QUIP)
flags = TRIM(flags)//" quip"
Expand Down
121 changes: 101 additions & 20 deletions src/cp_gemm_interface.F
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
!> \author Fawzi Mohamed
! **************************************************************************************************
MODULE cp_gemm_interface
USE ISO_C_BINDING, ONLY: C_CHAR,&
C_DOUBLE,&
C_INT,&
C_LOC,&
C_PTR
USE cp_dbcsr_operations, ONLY: copy_dbcsr_to_fm_bc,&
copy_fm_to_dbcsr_bc
USE cp_fm_basic_linalg, ONLY: cp_fm_gemm
Expand All @@ -21,10 +26,12 @@ MODULE cp_gemm_interface
USE dbcsr_api, ONLY: dbcsr_multiply,&
dbcsr_release,&
dbcsr_type
USE input_constants, ONLY: do_dbcsr,&
do_pdgemm
USE input_constants, ONLY: do_cosma,&
do_dbcsr,&
do_scalapack
USE kinds, ONLY: dp
USE message_passing, ONLY: mp_min
USE offload_api, ONLY: offload_set_device
USE string_utilities, ONLY: uppercase
#include "./base/base_uses.f90"

Expand Down Expand Up @@ -79,17 +86,18 @@ SUBROUTINE cp_gemm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, &
CALL timeset(routineN, handle)

my_multi = cp_fm_get_mm_type()

! catch the special case that matrices have different blocking
! SCALAPACK can deal with it but dbcsr doesn't like it
CALL cp_fm_get_info(matrix_a, nrow_locals=a_row_loc, ncol_locals=a_col_loc)
CALL cp_fm_get_info(matrix_b, nrow_locals=b_row_loc, ncol_locals=b_col_loc)
CALL cp_fm_get_info(matrix_c, nrow_locals=c_row_loc, ncol_locals=c_col_loc)
IF (PRESENT(a_first_row)) my_multi = do_pdgemm
IF (PRESENT(a_first_col)) my_multi = do_pdgemm
IF (PRESENT(b_first_row)) my_multi = do_pdgemm
IF (PRESENT(b_first_col)) my_multi = do_pdgemm
IF (PRESENT(c_first_row)) my_multi = do_pdgemm
IF (PRESENT(c_first_col)) my_multi = do_pdgemm
IF (PRESENT(a_first_row)) my_multi = do_scalapack
IF (PRESENT(a_first_col)) my_multi = do_scalapack
IF (PRESENT(b_first_row)) my_multi = do_scalapack
IF (PRESENT(b_first_col)) my_multi = do_scalapack
IF (PRESENT(c_first_row)) my_multi = do_scalapack
IF (PRESENT(c_first_col)) my_multi = do_scalapack
my_trans = transa; CALL uppercase(my_trans)
IF (my_trans == 'T') THEN
Expand All @@ -101,33 +109,33 @@ SUBROUTINE cp_gemm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, &
CALL cp_fm_get_info(matrix_b, nrow_locals=b_col_loc, ncol_locals=b_row_loc)
END IF
IF (my_multi .NE. do_pdgemm) THEN
IF (my_multi .NE. do_scalapack) THEN
IF (SIZE(a_row_loc) == SIZE(c_row_loc)) THEN
IF (ANY(a_row_loc - c_row_loc .NE. 0)) my_multi = do_pdgemm
IF (ANY(a_row_loc - c_row_loc .NE. 0)) my_multi = do_scalapack
ELSE
my_multi = do_pdgemm
my_multi = do_scalapack
END IF
END IF
IF (my_multi .NE. do_pdgemm) THEN
IF (my_multi .NE. do_scalapack) THEN
IF (SIZE(b_col_loc) == SIZE(c_col_loc)) THEN
IF (ANY(b_col_loc - c_col_loc .NE. 0)) my_multi = do_pdgemm
IF (ANY(b_col_loc - c_col_loc .NE. 0)) my_multi = do_scalapack
ELSE
my_multi = do_pdgemm
my_multi = do_scalapack
END IF
END IF
IF (my_multi .NE. do_pdgemm) THEN
IF (my_multi .NE. do_scalapack) THEN
IF (SIZE(a_col_loc) == SIZE(b_row_loc)) THEN
IF (ANY(a_col_loc - b_row_loc .NE. 0)) my_multi = do_pdgemm
IF (ANY(a_col_loc - b_row_loc .NE. 0)) my_multi = do_scalapack
ELSE
my_multi = do_pdgemm
my_multi = do_scalapack
END IF
END IF
! IMPORTANT do_pdgemm is lowest value. If one processor has it set make all do pdgemm
IF (cp_fm_get_mm_type() .NE. do_pdgemm) CALL mp_min(my_multi, matrix_a%matrix_struct%para_env%group)
! IMPORTANT do_scalapack is lowest value. If one processor has it set make all use it.
IF (cp_fm_get_mm_type() .NE. do_scalapack) CALL mp_min(my_multi, matrix_a%matrix_struct%para_env%group)
SELECT CASE (my_multi)
CASE (do_pdgemm)
CASE (do_scalapack)
CALL timeset("cp_gemm_fm_gemm", handle1)
CALL cp_fm_gemm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, matrix_c, &
a_first_col=a_first_col, &
Expand All @@ -150,9 +158,82 @@ SUBROUTINE cp_gemm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, &
CALL dbcsr_release(b_db)
CALL dbcsr_release(c_db)
CALL timestop(handle1)
CASE (do_cosma)
#if defined(__COSMA)
CALL timeset("cp_gemm_cosma", handle1)
CALL offload_set_device()
CALL cosma_pdgemm(transa=transa, transb=transb, m=m, n=n, k=k, alpha=alpha, &
matrix_a=matrix_a, matrix_b=matrix_b, beta=beta, matrix_c=matrix_c)
CALL timestop(handle1)
#else
CPABORT("CP2K compiled without the COSMA library.")
#endif
END SELECT
CALL timestop(handle)
END SUBROUTINE cp_gemm
#if defined(__COSMA)
! **************************************************************************************************
!> \brief Fortran wrapper for cosma_pdgemm.
!> \param transa ...
!> \param transb ...
!> \param m ...
!> \param n ...
!> \param k ...
!> \param alpha ...
!> \param matrix_a ...
!> \param matrix_b ...
!> \param beta ...
!> \param matrix_c ...
!> \author Ole Schuett
! **************************************************************************************************
SUBROUTINE cosma_pdgemm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, matrix_c)
CHARACTER(LEN=1), INTENT(IN) :: transa, transb
INTEGER, INTENT(IN) :: m, n, k
REAL(KIND=dp), INTENT(IN) :: alpha
TYPE(cp_fm_type), POINTER :: matrix_a, matrix_b
REAL(KIND=dp), INTENT(IN) :: beta
TYPE(cp_fm_type), POINTER :: matrix_c
INTERFACE
SUBROUTINE cosma_pdgemm_c(transa, transb, m, n, k, alpha, a, ia, ja, desca, &
b, ib, jb, descb, beta, c, ic, jc, descc) &
BIND(C, name="cosma_pdgemm")
IMPORT :: C_PTR, C_INT, C_DOUBLE, C_CHAR
CHARACTER(KIND=C_CHAR) :: transa
CHARACTER(KIND=C_CHAR) :: transb
INTEGER(KIND=C_INT) :: m
INTEGER(KIND=C_INT) :: n
INTEGER(KIND=C_INT) :: k
REAL(KIND=C_DOUBLE) :: alpha
TYPE(C_PTR), VALUE :: a
INTEGER(KIND=C_INT) :: ia
INTEGER(KIND=C_INT) :: ja
TYPE(C_PTR), VALUE :: desca
TYPE(C_PTR), VALUE :: b
INTEGER(KIND=C_INT) :: ib
INTEGER(KIND=C_INT) :: jb
TYPE(C_PTR), VALUE :: descb
REAL(KIND=C_DOUBLE) :: beta
TYPE(C_PTR), VALUE :: c
INTEGER(KIND=C_INT) :: ic
INTEGER(KIND=C_INT) :: jc
TYPE(C_PTR), VALUE :: descc
END SUBROUTINE cosma_pdgemm_c
END INTERFACE
CALL cosma_pdgemm_c(transa=transa, transb=transb, m=m, n=n, k=k, &
alpha=alpha, &
a=C_LOC(matrix_a%local_data(1, 1)), ia=1, ja=1, &
desca=C_LOC(matrix_a%matrix_struct%descriptor(1)), &
b=C_LOC(matrix_b%local_data(1, 1)), ib=1, jb=1, &
descb=C_LOC(matrix_b%matrix_struct%descriptor(1)), &
beta=beta, &
c=C_LOC(matrix_c%local_data(1, 1)), ic=1, jc=1, &
descc=C_LOC(matrix_c%matrix_struct%descriptor(1)))
END SUBROUTINE cosma_pdgemm
#endif
END MODULE cp_gemm_interface
21 changes: 17 additions & 4 deletions src/environment.F
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ MODULE environment
eps_check_diag_default
USE cp_fm_diag_utils, ONLY: cp_fm_redistribute_init
USE cp_fm_struct, ONLY: cp_fm_struct_config
USE cp_fm_types, ONLY: cp_fm_setup
USE cp_fm_types, ONLY: cp_fm_get_mm_type,&
cp_fm_setup
USE cp_log_handling, ONLY: &
cp_add_default_logger, cp_get_default_logger, cp_logger_create, &
cp_logger_get_default_unit_nr, cp_logger_release, cp_logger_set, cp_logger_type, &
Expand All @@ -57,9 +58,9 @@ MODULE environment
USE header, ONLY: cp2k_footer,&
cp2k_header
USE input_constants, ONLY: &
callgraph_all, callgraph_none, do_cp2k, do_diag_elpa, do_diag_scalapack, do_eip, &
do_farming, do_fft_fftw3, do_fft_sg, do_fist, do_qs, do_sirius, do_test, energy_run, &
id_development_version, mol_dyn_run, none_run
callgraph_all, callgraph_none, do_cosma, do_cp2k, do_dbcsr, do_diag_elpa, &
do_diag_scalapack, do_eip, do_farming, do_fft_fftw3, do_fft_sg, do_fist, do_qs, &
do_scalapack, do_sirius, do_test, energy_run, id_development_version, mol_dyn_run, none_run
USE input_cp2k_global, ONLY: create_global_section
USE input_enumeration_types, ONLY: enum_i2c,&
enumeration_type
Expand Down Expand Up @@ -789,6 +790,18 @@ SUBROUTINE read_global_section(root_section, para_env, globenv)
#endif
CALL section_release(section)

SELECT CASE (cp_fm_get_mm_type())
CASE (do_scalapack)
WRITE (UNIT=output_unit, FMT="(T2,A,T75,A6)") &
start_section_label//"| Matrix multiplication library", "SCALAPACK"
CASE (do_dbcsr)
WRITE (UNIT=output_unit, FMT="(T2,A,T75,A6)") &
start_section_label//"| Matrix multiplication library", "DBCSR"
CASE (do_cosma)
WRITE (UNIT=output_unit, FMT="(T2,A,T75,A6)") &
start_section_label//"| Matrix multiplication library", "COSMA"
END SELECT

CALL section_vals_val_get(global_section, "ALLTOALL_SGL", l_val=ata)
WRITE (UNIT=output_unit, FMT="(T2,A,T80,L1)") &
start_section_label//"| All-to-all communication in single precision", ata
Expand Down
5 changes: 3 additions & 2 deletions src/input_constants.F
Original file line number Diff line number Diff line change
Expand Up @@ -1085,8 +1085,9 @@ MODULE input_constants
sccs_derivative_cd7 = 3

! fm matrix multiplication
INTEGER, PARAMETER, PUBLIC :: do_pdgemm = 1, &
do_dbcsr = 2
INTEGER, PARAMETER, PUBLIC :: do_scalapack = 1, &
do_dbcsr = 2, &
do_cosma = 3

! Dispersion DFTB
INTEGER, PARAMETER, PUBLIC :: dispersion_uff = 100, &
Expand Down
33 changes: 21 additions & 12 deletions src/input_cp2k_global.F
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,13 @@ MODULE input_cp2k_global
GRID_BACKEND_REF
USE input_constants, ONLY: &
bsse_run, callgraph_all, callgraph_master, callgraph_none, cell_opt_run, debug_run, &
do_atom, do_band, do_cp2k, do_dbcsr, do_diag_elpa, do_diag_scalapack, do_farming, &
do_fft_fftw3, do_fft_sg, do_opt_basis, do_optimize_input, do_pdgemm, do_swarm, do_tamc, &
do_test, do_tree_mc, do_tree_mc_ana, driver_run, ehrenfest, electronic_spectra_run, &
energy_force_run, energy_run, fftw_plan_estimate, fftw_plan_exhaustive, fftw_plan_measure, &
fftw_plan_patient, gaussian, geo_opt_run, linear_response_run, mol_dyn_run, mon_car_run, &
negf_run, none_run, pint_run, real_time_propagation, tree_mc_run, vib_anal
do_atom, do_band, do_cosma, do_cp2k, do_dbcsr, do_diag_elpa, do_diag_scalapack, &
do_farming, do_fft_fftw3, do_fft_sg, do_opt_basis, do_optimize_input, do_scalapack, &
do_swarm, do_tamc, do_test, do_tree_mc, do_tree_mc_ana, driver_run, ehrenfest, &
electronic_spectra_run, energy_force_run, energy_run, fftw_plan_estimate, &
fftw_plan_exhaustive, fftw_plan_measure, fftw_plan_patient, gaussian, geo_opt_run, &
linear_response_run, mol_dyn_run, mon_car_run, negf_run, none_run, pint_run, &
real_time_propagation, tree_mc_run, vib_anal
USE input_keyword_types, ONLY: keyword_create,&
keyword_release,&
keyword_type
Expand Down Expand Up @@ -593,6 +594,7 @@ END SUBROUTINE create_global_section
SUBROUTINE create_fm_section(section)
TYPE(section_type), POINTER :: section

INTEGER :: default_matmul
TYPE(keyword_type), POINTER :: keyword

CPASSERT(.NOT. ASSOCIATED(section))
Expand Down Expand Up @@ -624,18 +626,25 @@ SUBROUTINE create_fm_section(section)
CALL section_add_keyword(section, keyword)
CALL keyword_release(keyword)

#if defined(__COSMA)
default_matmul = do_cosma
#else
default_matmul = do_scalapack
#endif

CALL keyword_create(keyword, __LOCATION__, name="TYPE_OF_MATRIX_MULTIPLICATION", &
description="Allows to switch between scalapack pdgemm and dbcsr_multiply. "// &
"On normal systems pdgemm is recommended on system with GPU "// &
"is optimized and can give better performance. NOTE: if DBCSR is employed "// &
"FORCE_BLOCK_SIZE should be set. The performance on GPU's depends "// &
"crucially on the BLOCK_SIZES. Make sure optimized kernels are available.", &
usage="TYPE_OF_MATRIX_MULTIPLICATION ELPA", &
default_i_val=do_pdgemm, &
enum_i_vals=(/do_pdgemm, do_dbcsr/), &
enum_c_vals=s2a("PDGEMM", "DBCSR_MM"), &
enum_desc=s2a("Standard scalapack: pdgemm", &
"DBCSR_MM is employed. This needs local transformation of the matrices"))
default_i_val=default_matmul, &
enum_i_vals=(/do_scalapack, do_scalapack, do_dbcsr, do_cosma/), &
enum_c_vals=s2a("SCALAPACK", "PDGEMM", "DBCSR_MM", "COSMA"), &
enum_desc=s2a("Standard ScaLAPACK pdgemm", &
"Alias for SCALAPACK", &
"DBCSR_MM is employed. This needs local transformation of the matrices", &
"COSMA is employed. See https://github.com/eth-cscs/COSMA."))
CALL section_add_keyword(section, keyword)
CALL keyword_release(keyword)

Expand Down
3 changes: 3 additions & 0 deletions tools/docker/scripts/test_regtest.sh
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ if [[ "${ARCH}" == "local" ]]; then
fi
fi

# Improve code coverage on COSMA.
export COSMA_DIM_THRESHOLD=0

# Run regtests.
echo -e "\n========== Running Regtests =========="
make ARCH="${ARCH}" VERSION="${VERSION}" TESTOPTS="${TESTOPTS}" test
Expand Down

0 comments on commit 8ad318d

Please sign in to comment.