Skip to content

Commit

Permalink
Switch to mpi_f08
Browse files Browse the repository at this point in the history
  • Loading branch information
Frederick Stein authored and fstein93 committed Jan 10, 2023
1 parent cdc973b commit 9b6bceb
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 63 deletions.
133 changes: 95 additions & 38 deletions src/mpiwrap/message_passing.F
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ MODULE message_passing
#include "../base/base_uses.f90"

#if defined(__parallel)
USE mpi
USE mpi_f08
! subroutines: unfortunately, mpi implementations do not provide interfaces for all subroutines (problems with types and ranks explosion),
! we do not quite know what is in the module, so we can not include any....
! to nevertheless get checking for what is included, we use the mpi module without use clause, getting all there is
Expand Down Expand Up @@ -59,20 +59,21 @@ MODULE message_passing

! parameters that might be needed
#if defined(__parallel)
INTEGER, PARAMETER :: MP_STD_REAL = MPI_DOUBLE_PRECISION
INTEGER, PARAMETER :: MP_STD_COMPLEX = MPI_DOUBLE_COMPLEX
INTEGER, PARAMETER :: MP_STD_HALF_REAL = MPI_REAL
INTEGER, PARAMETER :: MP_STD_HALF_COMPLEX = MPI_COMPLEX
TYPE(MPI_Datatype), PARAMETER :: MP_STD_REAL = MPI_DOUBLE_PRECISION
TYPE(MPI_Datatype), PARAMETER :: MP_STD_COMPLEX = MPI_DOUBLE_COMPLEX
TYPE(MPI_Datatype), PARAMETER :: MP_STD_HALF_REAL = MPI_REAL
TYPE(MPI_Datatype), PARAMETER :: MP_STD_HALF_COMPLEX = MPI_COMPLEX

LOGICAL, PARAMETER :: cp2k_is_parallel = .TRUE.
INTEGER, PARAMETER, PUBLIC :: mp_any_tag = MPI_ANY_TAG
INTEGER, PARAMETER, PUBLIC :: mp_any_source = MPI_ANY_SOURCE
INTEGER, PARAMETER :: mp_comm_null_handle = MPI_COMM_NULL
INTEGER, PARAMETER :: mp_comm_self_handle = MPI_COMM_SELF
INTEGER, PARAMETER :: mp_comm_world_handle = MPI_COMM_WORLD
INTEGER, PARAMETER :: mp_request_null_handle = MPI_REQUEST_NULL
INTEGER, PARAMETER :: mp_win_null_handle = MPI_WIN_NULL
INTEGER, PARAMETER :: mp_file_null_handle = MPI_FILE_NULL
TYPE(MPI_COMM), PARAMETER :: mp_comm_null_handle = MPI_COMM_NULL
TYPE(MPI_COMM), PARAMETER :: mp_comm_self_handle = MPI_COMM_SELF
TYPE(MPI_COMM), PARAMETER :: mp_comm_world_handle = MPI_COMM_WORLD
TYPE(MPI_REQUEST), PARAMETER :: mp_request_null_handle = MPI_REQUEST_NULL
TYPE(MPI_WIN), PARAMETER :: mp_win_null_handle = MPI_WIN_NULL
TYPE(MPI_FILE), PARAMETER :: mp_file_null_handle = MPI_FILE_NULL
TYPE(MPI_Info), PARAMETER :: mp_info_null_handle = MPI_INFO_NULL
INTEGER, PARAMETER, PUBLIC :: mp_status_size = MPI_STATUS_SIZE
INTEGER, PARAMETER, PUBLIC :: mp_proc_null = MPI_PROC_NULL
! Set max allocatable memory by MPI to 2 GiByte
Expand All @@ -98,8 +99,9 @@ MODULE message_passing
INTEGER, PARAMETER :: mp_request_null_handle = -4
INTEGER, PARAMETER :: mp_win_null_handle = -5
INTEGER, PARAMETER :: mp_file_null_handle = -6
INTEGER, PARAMETER, PUBLIC :: mp_status_size = -7
INTEGER, PARAMETER, PUBLIC :: mp_proc_null = -8
INTEGER, PARAMETER :: mp_info_null_handle = -7
INTEGER, PARAMETER, PUBLIC :: mp_status_size = -8
INTEGER, PARAMETER, PUBLIC :: mp_proc_null = -9
INTEGER, PARAMETER, PUBLIC :: mp_max_library_version_string = 1

INTEGER, PARAMETER, PUBLIC :: file_offset = int_8
Expand Down Expand Up @@ -128,10 +130,15 @@ MODULE message_passing
PUBLIC :: mp_request_type
PUBLIC :: mp_win_type
PUBLIC :: mp_file_type
PUBLIC :: mp_info_type

TYPE mp_comm_type
PRIVATE
#if defined(__parallel)
TYPE(MPI_COMM) :: handle = mp_comm_null_handle
#else
INTEGER :: handle = mp_comm_null_handle
#endif
CONTAINS
PROCEDURE :: set_handle => mp_comm_type_set_handle
PROCEDURE :: get_handle => mp_comm_type_get_handle
Expand All @@ -143,7 +150,11 @@ MODULE message_passing

TYPE mp_request_type
PRIVATE
#if defined(__parallel)
TYPE(MPI_REQUEST) :: handle = mp_request_null_handle
#else
INTEGER :: handle = mp_request_null_handle
#endif
CONTAINS
PROCEDURE :: set_handle => mp_request_type_set_handle
PROCEDURE :: get_handle => mp_request_type_get_handle
Expand All @@ -155,7 +166,11 @@ MODULE message_passing

TYPE mp_win_type
PRIVATE
#if defined(__parallel)
TYPE(MPI_WIN) :: handle = mp_win_null_handle
#else
INTEGER :: handle = mp_win_null_handle
#endif
CONTAINS
PROCEDURE :: set_handle => mp_win_type_set_handle
PROCEDURE :: get_handle => mp_win_type_get_handle
Expand All @@ -167,7 +182,11 @@ MODULE message_passing

TYPE mp_file_type
PRIVATE
#if defined(__parallel)
TYPE(MPI_FILE) :: handle = mp_file_null_handle
#else
INTEGER :: handle = mp_file_null_handle
#endif
CONTAINS
PROCEDURE :: set_handle => mp_file_type_set_handle
PROCEDURE :: get_handle => mp_file_type_get_handle
Expand All @@ -177,13 +196,30 @@ MODULE message_passing
GENERIC, PUBLIC :: OPERATOR(.NE.) => mp_file_op_neq
END TYPE

TYPE mp_info_type
PRIVATE
#if defined(__parallel)
TYPE(MPI_Info) :: handle = mp_info_null_handle
#else
INTEGER :: handle = mp_info_null_handle
#endif
CONTAINS
PROCEDURE :: set_handle => mp_info_type_set_handle
PROCEDURE :: get_handle => mp_info_type_get_handle
PROCEDURE, PRIVATE :: mp_info_op_eq
PROCEDURE, PRIVATE :: mp_info_op_neq
GENERIC, PUBLIC :: OPERATOR(.EQ.) => mp_info_op_eq
GENERIC, PUBLIC :: OPERATOR(.NE.) => mp_info_op_neq
END TYPE

! Create the constants from the corresponding handles
TYPE(mp_comm_type), PARAMETER, PUBLIC :: mp_comm_null = mp_comm_type(mp_comm_null_handle)
TYPE(mp_comm_type), PARAMETER, PUBLIC :: mp_comm_self = mp_comm_type(mp_comm_self_handle)
TYPE(mp_comm_type), PARAMETER, PUBLIC :: mp_comm_world = mp_comm_type(mp_comm_world_handle)
TYPE(mp_request_type), PARAMETER, PUBLIC :: mp_request_null = mp_request_type(mp_request_null_handle)
TYPE(mp_win_type), PARAMETER, PUBLIC :: mp_win_null = mp_win_type(mp_win_null_handle)
TYPE(mp_file_type), PARAMETER, PUBLIC :: mp_file_null = mp_file_type(mp_file_null_handle)
TYPE(mp_info_type), PARAMETER, PUBLIC :: mp_info_null = mp_info_type(mp_info_null_handle)

! init and error
PUBLIC :: mp_world_init, mp_world_finalize
Expand Down Expand Up @@ -688,7 +724,11 @@ MODULE message_passing
END TYPE mp_indexing_meta_type

TYPE mp_type_descriptor_type
#if defined(__parallel)
TYPE(MPI_Datatype) :: type_handle
#else
INTEGER :: type_handle
#endif
INTEGER :: length
#if defined(__parallel)
INTEGER(kind=mpi_address_kind) :: base
Expand All @@ -712,7 +752,11 @@ MODULE message_passing
END TYPE mp_file_indexing_meta_type

TYPE mp_file_descriptor_type
#if defined(__parallel)
TYPE(MPI_Datatype) :: type_handle
#else
INTEGER :: type_handle
#endif
INTEGER :: length
LOGICAL :: has_indexing = .FALSE.
TYPE(mp_file_indexing_meta_type) :: index_descriptor
Expand Down Expand Up @@ -769,7 +813,7 @@ MODULE message_passing
CONTAINS

#:mute
#:set types = ["comm", "request", "win", "file"]
#:set types = ["comm", "request", "win", "file", "info"]
#:endmute
#:for type in types
LOGICAL FUNCTION mp_${type}$_op_eq(${type}$1, ${type}$2)
Expand All @@ -786,14 +830,22 @@ ELEMENTAL SUBROUTINE mp_${type}$_type_set_handle(this, handle)
CLASS(mp_${type}$_type), INTENT(INOUT) :: this
INTEGER, INTENT(IN) :: handle

#if defined(__parallel)
this%handle%mpi_val = handle
#else
this%handle = handle
#endif
END SUBROUTINE mp_${type}$_type_set_handle

ELEMENTAL FUNCTION mp_${type}$_type_get_handle(this) RESULT(handle)
CLASS(mp_${type}$_type), INTENT(IN) :: this
INTEGER :: handle

#if defined(__parallel)
handle = this%handle%mpi_val
#else
handle = this%handle
#endif
END FUNCTION mp_${type}$_type_get_handle
#:endfor

Expand Down Expand Up @@ -869,7 +921,7 @@ SUBROUTINE mp_reordering(mp_comm, mp_new_comm, ranks_order)
INTEGER :: handle, ierr
#if defined(__parallel)
TYPE(mp_comm_type) :: newcomm
INTEGER :: newgroup, oldgroup
TYPE(MPI_Group) :: newgroup, oldgroup
#endif

CALL mp_timeset(routineN, handle)
Expand Down Expand Up @@ -1593,8 +1645,9 @@ SUBROUTINE mp_rank_compare(comm1, comm2, rank)

INTEGER :: handle, ierr
#if defined(__parallel)
INTEGER :: g1, g2, i, n, n1, n2
INTEGER :: i, n, n1, n2
INTEGER, ALLOCATABLE, DIMENSION(:) :: rin
TYPE(MPI_Group) :: g1, g2
#endif

ierr = 0
Expand Down Expand Up @@ -1744,15 +1797,15 @@ SUBROUTINE mp_waitall_1(requests)
INTEGER :: handle, ierr
#if defined(__parallel)
INTEGER :: count
INTEGER, ALLOCATABLE, DIMENSION(:, :) :: status
TYPE(MPI_Status), ALLOCATABLE, DIMENSION(:) :: status
#endif

ierr = 0
CALL mp_timeset(routineN, handle)

#if defined(__parallel)
count = SIZE(requests)
ALLOCATE (status(MPI_STATUS_SIZE, count))
ALLOCATE (status(count))
CALL mpi_waitall_internal(count, requests, status, ierr) ! MPI_STATUSES_IGNORE openmpi workaround
IF (ierr /= 0) CALL mp_stop(ierr, "mpi_waitall @ mp_waitall_1")
DEALLOCATE (status)
Expand All @@ -1778,15 +1831,15 @@ SUBROUTINE mp_waitall_2(requests)
INTEGER :: handle, ierr
#if defined(__parallel)
INTEGER :: count
INTEGER, ALLOCATABLE, DIMENSION(:, :) :: status
TYPE(MPI_Status), ALLOCATABLE, DIMENSION(:) :: status
#endif

ierr = 0
CALL mp_timeset(routineN, handle)

#if defined(__parallel)
count = SIZE(requests)
ALLOCATE (status(MPI_STATUS_SIZE, count))
ALLOCATE (status(count))

CALL mpi_waitall_internal(count, requests, status, ierr) ! MPI_STATUSES_IGNORE openmpi workaround
IF (ierr /= 0) CALL mp_stop(ierr, "mpi_waitall @ mp_waitall_2")
Expand All @@ -1812,12 +1865,12 @@ END SUBROUTINE mp_waitall_2
SUBROUTINE mpi_waitall_internal(count, array_of_requests, array_of_statuses, ierr)
INTEGER, INTENT(in) :: count
TYPE(mp_request_type), DIMENSION(count), INTENT(inout) :: array_of_requests
INTEGER, DIMENSION(MPI_STATUS_SIZE, *), &
TYPE(MPI_Status), DIMENSION(*), &
INTENT(out) :: array_of_statuses
INTEGER, INTENT(out) :: ierr

INTEGER :: i
INTEGER, ALLOCATABLE, DIMENSION(:) :: request_handles
TYPE(MPI_Request), ALLOCATABLE, DIMENSION(:) :: request_handles

ALLOCATE (request_handles(count))
DO i = 1, count
Expand Down Expand Up @@ -1850,7 +1903,7 @@ SUBROUTINE mp_waitany(requests, completed)
INTEGER :: handle, ierr
#if defined(__parallel)
INTEGER :: count, i
INTEGER, ALLOCATABLE, DIMENSION(:) :: request_handles
TYPE(MPI_Request), ALLOCATABLE, DIMENSION(:) :: request_handles
#endif

ierr = 0
Expand Down Expand Up @@ -2030,11 +2083,11 @@ SUBROUTINE mpi_testany_internal(count, array_of_requests, index, flag, status, i
TYPE(mp_request_type), DIMENSION(count), INTENT(inout) :: array_of_requests
INTEGER, INTENT(out) :: index
LOGICAL, INTENT(out) :: flag
INTEGER, DIMENSION(MPI_STATUS_SIZE), INTENT(out) :: status
TYPE(MPI_Status), INTENT(out) :: status
INTEGER, INTENT(out) :: ierr
INTEGER :: i
INTEGER, ALLOCATABLE, DIMENSION(:) :: request_handles
TYPE(MPI_Request), ALLOCATABLE, DIMENSION(:) :: request_handles
ALLOCATE (request_handles(count))
DO i = 1, count
Expand Down Expand Up @@ -2272,7 +2325,7 @@ SUBROUTINE mp_probe(source, comm, tag)
INTEGER :: handle, ierr
#if defined(__parallel)
INTEGER, DIMENSION(mp_status_size) :: status_single
TYPE(MPI_Status) :: status_single
LOGICAL :: flag
#endif
Expand All @@ -2285,8 +2338,8 @@ SUBROUTINE mp_probe(source, comm, tag)
IF (source .EQ. mp_any_source) THEN
CALL mpi_probe(mp_any_source, mp_any_tag, comm%handle, status_single, ierr)
IF (ierr /= 0) CALL mp_stop(ierr, "mpi_probe @ mp_probe")
source = status_single(MPI_SOURCE)
tag = status_single(MPI_TAG)
source = status_single%MPI_SOURCE
tag = status_single%MPI_TAG
ELSE
flag = .FALSE.
CALL mpi_iprobe(source, mp_any_tag, comm%handle, flag, status_single, ierr)
Expand All @@ -2295,7 +2348,7 @@ SUBROUTINE mp_probe(source, comm, tag)
source = mp_any_source
tag = -1 !status_single(MPI_TAG) ! in case of flag==false status is undefined
ELSE
tag = status_single(MPI_TAG)
tag = status_single%MPI_TAG
END IF
END IF
#else
Expand Down Expand Up @@ -3213,11 +3266,11 @@ SUBROUTINE mp_file_open(groupid, fh, filepath, amode_status, info)
TYPE(mp_file_type), INTENT(OUT) :: fh
CHARACTER(len=*), INTENT(IN) :: filepath
INTEGER, INTENT(IN) :: amode_status
INTEGER, INTENT(IN), OPTIONAL :: info
TYPE(mp_info_type), INTENT(IN), OPTIONAL :: info
INTEGER :: ierr, istat
#if defined(__parallel)
INTEGER :: my_info
TYPE(MPI_Info) :: my_info
#else
CHARACTER(LEN=10) :: fstatus, fposition
INTEGER :: amode, handle
Expand All @@ -3228,7 +3281,7 @@ SUBROUTINE mp_file_open(groupid, fh, filepath, amode_status, info)
istat = 0
#if defined(__parallel)
my_info = mpi_info_null
IF (PRESENT(info)) my_info = info
IF (PRESENT(info)) my_info = info%handle
CALL mpi_file_open(groupid%handle, filepath, amode_status, my_info, fh%handle, ierr)
CALL mpi_file_set_errhandler(fh%handle, MPI_ERRORS_RETURN, ierr)
IF (ierr .NE. 0) CALL mp_stop(ierr, "mpi_file_set_errhandler @ mp_file_open")
Expand Down Expand Up @@ -3269,18 +3322,18 @@ END SUBROUTINE mp_file_open
! **************************************************************************************************
SUBROUTINE mp_file_delete(filepath, info)
CHARACTER(len=*), INTENT(IN) :: filepath
INTEGER, INTENT(IN), OPTIONAL :: info
TYPE(mp_info_type), INTENT(IN), OPTIONAL :: info
#if defined(__parallel)
INTEGER :: ierr
INTEGER :: my_info
TYPE(MPI_Info) :: my_info
LOGICAL :: exists
#endif
#if defined(__parallel)
ierr = 0
my_info = mpi_info_null
IF (PRESENT(info)) my_info = info
IF (PRESENT(info)) my_info = info%handle
INQUIRE (FILE=filepath, EXIST=exists)
IF (exists) CALL mpi_file_delete(filepath, my_info, ierr)
IF (ierr .NE. 0) CALL mp_stop(ierr, "mpi_file_set_errhandler @ mp_file_delete")
Expand Down Expand Up @@ -3650,11 +3703,14 @@ FUNCTION mp_type_make_struct(subtypes, &
CHARACTER(len=*), PARAMETER :: routineN = 'mp_type_make_struct'
INTEGER :: i, ierr, n
INTEGER, ALLOCATABLE, DIMENSION(:) :: lengths
#if defined(__parallel)
INTEGER(kind=mpi_address_kind), &
ALLOCATABLE, DIMENSION(:) :: displacements
TYPE(MPI_Datatype), ALLOCATABLE, DIMENSION(:) :: old_types
#else
INTEGER, ALLOCATABLE, DIMENSION(:) :: old_types
#endif
INTEGER, ALLOCATABLE, DIMENSION(:) :: lengths, old_types
ierr = 0
n = SIZE(subtypes)
Expand Down Expand Up @@ -3764,7 +3820,6 @@ FUNCTION mp_file_type_hindexed_make_chv(count, lengths, displs) &

ierr = 0
CALL mp_timeset(routineN, handle)
type_descriptor%type_handle = 0

#if defined(__parallel)
CALL MPI_Type_create_hindexed(count, lengths, INT(displs, KIND=address_kind), MPI_CHARACTER, &
Expand Down Expand Up @@ -3944,9 +3999,11 @@ SUBROUTINE mp_file_type_free(type_descriptor)
CALL MPI_Type_free(type_descriptor%type_handle, ierr)
IF (ierr /= 0) &
CPABORT("MPI_Type_free @ "//routineN)
type_descriptor%type_handle%mpi_val = -1
#else
type_descriptor%type_handle = -1
#endif
type_descriptor%length = -1
type_descriptor%type_handle = -1
IF (type_descriptor%has_indexing) THEN
NULLIFY (type_descriptor%index_descriptor%index)
NULLIFY (type_descriptor%index_descriptor%chunks)
Expand Down

0 comments on commit 9b6bceb

Please sign in to comment.