Skip to content

Commit

Permalink
Move metadata from cp_para_types and cp_blacs_env to the respective l…
Browse files Browse the repository at this point in the history
…ow-level types
  • Loading branch information
Frederick Stein authored and fstein93 committed Feb 28, 2023
1 parent cfb00a7 commit 066d5cd
Show file tree
Hide file tree
Showing 40 changed files with 297 additions and 275 deletions.
47 changes: 14 additions & 33 deletions src/common/cp_para_types.F
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,6 @@ MODULE cp_para_types

! **************************************************************************************************
!> \brief stores all the informations relevant to an mpi environment
!> \param mepos rank of the actual processor
!> \param num_pe number of processors in the communicator
!> \param source rank of a special processor (for example the one for i-o,
!> or the master
!> \param owns_group if it owns the group (and thus should free it when
!> this object is deallocated)
!> \param ref_count the reference count, when it is zero this object gets
Expand All @@ -41,9 +37,8 @@ MODULE cp_para_types
! **************************************************************************************************
TYPE, EXTENDS(mp_comm_type) :: cp_para_env_type
PRIVATE
! We set it to true because to have less initialization steps in case we create a new communicator
! We set it to true to have less initialization steps in case we create a new communicator
LOGICAL :: owns_group = .TRUE.
INTEGER, PUBLIC :: mepos = -1, source = -1, num_pe = -1
INTEGER :: ref_count = -1
CONTAINS
PROCEDURE, PRIVATE, PASS(comm), NON_OVERRIDABLE :: mp_comm_init => cp_para_env_init
Expand All @@ -66,9 +61,9 @@ MODULE cp_para_types

! **************************************************************************************************
!> \brief represent a multidimensional parallel environment
!> \param mepos the position of the actual processor
!> \param num_pe number of processors in the group in each dimension
!> \param source id of a special processor (for example the one for i-o,
!> \param mepos_cart the position of the actual processor
!> \param num_pe_cart number of processors in the group in each dimension
!> \param source_cart id of a special processor (for example the one for i-o,
!> or the master
!> \param owns_group if it owns the group (and thus should free it when
!> this object is deallocated)
Expand All @@ -82,11 +77,8 @@ MODULE cp_para_types
! **************************************************************************************************
TYPE, EXTENDS(mp_cart_type) :: cp_para_cart_type
PRIVATE
! We set it to true because to have less initialization steps in case we create a new communicator
! We set it to true to have less initialization steps in case we create a new communicator
LOGICAL :: owns_group = .TRUE.
INTEGER, PUBLIC :: rank = -1, ntask = -1
INTEGER, DIMENSION(:), ALLOCATABLE, PUBLIC :: mepos, source, num_pe
LOGICAL, DIMENSION(:), ALLOCATABLE, PUBLIC :: periodic
INTEGER :: ref_count = -1
CONTAINS
PROCEDURE, PRIVATE, PASS(comm), NON_OVERRIDABLE :: mp_comm_free => cp_para_cart_free
Expand All @@ -105,12 +97,11 @@ ELEMENTAL IMPURE SUBROUTINE cp_para_env_init(comm, owns_group)
CLASS(cp_para_env_type), INTENT(INOUT) :: comm
LOGICAL, INTENT(IN), OPTIONAL :: owns_group

IF (PRESENT(owns_group)) comm%owns_group = owns_group
CALL comm%mp_comm_type%init()

comm%source = 0
IF (PRESENT(owns_group)) comm%owns_group = owns_group
comm%ref_count = 1
CALL comm%mp_comm_type%get_size(comm%num_pe)
CALL comm%mp_comm_type%get_rank(comm%mepos)

END SUBROUTINE cp_para_env_init

! **************************************************************************************************
Expand Down Expand Up @@ -172,23 +163,11 @@ ELEMENTAL IMPURE SUBROUTINE cp_para_cart_init(comm, owns_group)
CLASS(cp_para_cart_type), INTENT(INOUT) :: comm
LOGICAL, INTENT(IN), OPTIONAL :: owns_group

INTEGER :: ndims
CALL comm%mp_cart_type%init()

IF (PRESENT(owns_group)) comm%owns_group = owns_group
ndims = comm%get_ndims()

ALLOCATE (comm%source(ndims), comm%periodic(ndims), comm%mepos(ndims), &
comm%num_pe(ndims))

comm%source = 0
comm%mepos = 0
comm%periodic = .FALSE.
comm%ref_count = 1
comm%ntask = 1
CALL comm%get_info_cart(comm%num_pe, comm%mepos, &
comm%periodic)
CALL comm%get_size(comm%ntask)
CALL comm%get_rank(comm%rank)
END SUBROUTINE cp_para_cart_init

! **************************************************************************************************
Expand Down Expand Up @@ -223,9 +202,11 @@ SUBROUTINE cp_para_cart_free(comm)
comm%ref_count = comm%ref_count - 1
IF (comm%ref_count <= 0) THEN
IF (comm%owns_group) CALL comm%mp_cart_type%free()
DEALLOCATE (comm%source, comm%periodic, comm%mepos, comm%num_pe)
IF (comm%owns_group) THEN
CALL comm%mp_cart_type%free()
ELSE
DEALLOCATE (comm%periodic, comm%mepos_cart, comm%num_pe_cart)
END IF
END IF
END SUBROUTINE cp_para_cart_free
Expand Down
2 changes: 1 addition & 1 deletion src/common/parallel_rng_types_unittest.F
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ PROGRAM parallel_rng_types_TEST
END IF

CALL mp_world_init(mpi_comm)
CALL mpi_comm%get_rank(mepos)
mepos = mpi_comm%mepos
ionode = mepos == 0

CALL check_rng(default_output_unit, ionode)
Expand Down
8 changes: 4 additions & 4 deletions src/dbm/dbm_tests.F
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,11 @@ SUBROUTINE dbm_run_tests(mp_group, io_unit, matrix_sizes, trs, &
CALL timeset(routineN, handle)
! Create MPI processor grid.
CALL mp_group%get_size(numnodes)
numnodes = mp_group%num_pe
npdims(:) = 0
CALL mp_dims_create(numnodes, npdims)
CALL cart_group%create(mp_group, 2, npdims)
CALL cart_group%get_info_cart(npdims)
npdims = cart_group%num_pe_cart
! Initialize random number generator.
randmat_counter = 12341313
Expand Down Expand Up @@ -221,7 +221,7 @@ SUBROUTINE run_multiply_test(matrix_a, matrix_b, matrix_c, transa, transb, alpha
CALL timeset(routineN, handle)
CALL group%get_size(numnodes)
numnodes = group%num_pe
CALL dbm_create_from_template(matrix_c_orig, "Original Matrix C", matrix_c)
CALL dbm_copy(matrix_c_orig, matrix_c)
Expand Down Expand Up @@ -288,7 +288,7 @@ SUBROUTINE fill_matrix(matrix, sparsity, group)
CALL timeset(routineN, handle)
CALL group%get_rank(mynode)
mynode = group%mepos
! Check that the counter was initialised (or has not overflowed)
CPASSERT(randmat_counter .NE. 0)
Expand Down
6 changes: 3 additions & 3 deletions src/dbt/dbt_io.F
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ SUBROUTINE dbt_write_tensor_dist(tensor, unit_nr)
unit_nr_prv = prep_output_unit(unit_nr)
IF (unit_nr_prv == 0) RETURN

CALL tensor%pgrid%mp_comm_2d%get_size(nproc)
nproc = tensor%pgrid%mp_comm_2d%num_pe

nblock = dbt_get_num_blocks(tensor)
nelement = dbt_get_nze(tensor)
Expand Down Expand Up @@ -190,7 +190,7 @@ SUBROUTINE dbt_write_blocks(tensor, io_unit_master, io_unit_all, write_int)
DO WHILE (dbt_iterator_blocks_left(iterator))
CALL dbt_iterator_next_block(iterator, blk_index, blk_size=blk_size)
CALL dbt_get_stored_coordinates(tensor, blk_index, proc)
CALL tensor%pgrid%mp_comm_2d%get_rank(mynode)
mynode = tensor%pgrid%mp_comm_2d%mepos
CPASSERT(proc .EQ. mynode)
#:for ndim in ndims
IF (ndims_tensor(tensor) == ${ndim}$) THEN
Expand Down Expand Up @@ -284,7 +284,7 @@ SUBROUTINE dbt_write_block_indices(tensor, io_unit_master, io_unit_all)
DO WHILE (dbt_iterator_blocks_left(iterator))
CALL dbt_iterator_next_block(iterator, blk_index, blk_size=blk_size)
CALL dbt_get_stored_coordinates(tensor, blk_index, proc)
CALL tensor%pgrid%mp_comm_2d%get_rank(mynode)
mynode = tensor%pgrid%mp_comm_2d%mepos
CPASSERT(proc .EQ. mynode)
#:for ndim in ndims
IF (ndims_tensor(tensor) == ${ndim}$) THEN
Expand Down
10 changes: 5 additions & 5 deletions src/dbt/dbt_methods.F
Original file line number Diff line number Diff line change
Expand Up @@ -956,15 +956,15 @@ SUBROUTINE dbt_contract_expert(alpha, tensor_1, tensor_2, beta, tensor_3, &
mp_comm_opt = dbt_tas_mp_comm(tensor_small%pgrid%mp_comm_2d, split_opt%split_rowcol, NINT(storage%nsplit_avg))
CALL dbt_tas_create_split(split_opt_avg, mp_comm_opt, split_opt%split_rowcol, &
NINT(storage%nsplit_avg), own_comm=.TRUE.)
CALL split_opt_avg%mp_comm%get_info_cart(pdims_2d_opt)
pdims_2d_opt = split_opt_avg%mp_comm%num_pe_cart
END IF
END ASSOCIATE
IF (do_change_pgrid(1) .AND. .NOT. do_change_pgrid(2)) THEN
! check if new grid has better subgrid, if not there is no need to change process grid
CALL split_opt_avg%mp_comm_group%get_info_cart(pdims_sub_opt)
CALL split%mp_comm_group%get_info_cart(pdims_sub)
pdims_sub_opt = split_opt_avg%mp_comm_group%num_pe_cart
pdims_sub = split%mp_comm_group%num_pe_cart
pdim_ratio = MAXVAL(REAL(pdims_sub, dp))/MINVAL(pdims_sub)
pdim_ratio_opt = MAXVAL(REAL(pdims_sub_opt, dp))/MINVAL(pdims_sub_opt)
Expand Down Expand Up @@ -1484,7 +1484,7 @@ FUNCTION update_contraction_storage(storage, split_opt, split) RESULT(do_change_
nsplit_opt = split_opt%ngroup_opt
nsplit = split%ngroup

CALL split%mp_comm%get_info_cart(pdims)
pdims = split%mp_comm%num_pe_cart

storage%ibatch = storage%ibatch + 1

Expand All @@ -1501,7 +1501,7 @@ FUNCTION update_contraction_storage(storage, split_opt, split) RESULT(do_change_
do_change_pgrid(:) = .FALSE.

! check for process grid dimensions
CALL split%mp_comm_group%get_info_cart(pdims_sub)
pdims_sub = split%mp_comm_group%num_pe_cart
change_criterion = MAXVAL(REAL(pdims_sub, dp))/MINVAL(pdims_sub)
IF (change_criterion > default_pdims_accept_ratio**2) do_change_pgrid(1) = .TRUE.

Expand Down
4 changes: 2 additions & 2 deletions src/dbt/dbt_reshape_ops.F
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ SUBROUTINE dbt_reshape(tensor_in, tensor_out, summation, move_data)
IF (.NOT. summation_prv) CALL dbt_clear(tensor_out)

mp_comm = tensor_in%pgrid%mp_comm_2d
CALL mp_comm%get_size(numnodes)
numnodes = mp_comm%num_pe
ALLOCATE (buffer_send(0:numnodes - 1), buffer_recv(0:numnodes - 1))
ALLOCATE (nblks_send_total(0:numnodes - 1), ndata_send_total(0:numnodes - 1), source=0)
ALLOCATE (nblks_recv_total(0:numnodes - 1), ndata_recv_total(0:numnodes - 1), source=0)
Expand Down Expand Up @@ -235,7 +235,7 @@ SUBROUTINE dbt_communicate_buffer(mp_comm, buffer_recv, buffer_send)
INTEGER :: handle

CALL timeset(routineN, handle)
CALL mp_comm%get_size(numnodes)
numnodes = mp_comm%num_pe

IF (numnodes > 1) THEN
!$OMP MASTER
Expand Down
11 changes: 5 additions & 6 deletions src/dbt/dbt_test.F
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ MODULE dbt_test
USE dbt_block, ONLY: block_nd
USE dbt_types, ONLY: &
dbt_create, dbt_destroy, dbt_type, dbt_distribution_type, &
dbt_distribution_destroy, &
dims_tensor, ndims_tensor, dbt_distribution_new, dbt_nd_mp_comm, &
dbt_distribution_destroy, dims_tensor, ndims_tensor, dbt_distribution_new, &
mp_environ_pgrid, dbt_pgrid_type, dbt_pgrid_create, dbt_pgrid_destroy, dbt_get_info, &
dbt_default_distvec
USE dbt_io, ONLY: &
Expand Down Expand Up @@ -201,7 +200,7 @@ SUBROUTINE dbt_test_formats(ndims, mp_comm, unit_nr, verbose, &
! Process grid
pdims(:) = 0
CALL dbt_pgrid_create(mp_comm, pdims, comm_nd)
CALL mp_comm%get_rank(mynode)
mynode = mp_comm%mepos

io_unit = 0
IF (mynode .EQ. 0) io_unit = unit_nr
Expand Down Expand Up @@ -259,7 +258,7 @@ SUBROUTINE dbt_test_formats(ndims, mp_comm, unit_nr, verbose, &
ALLOCATE (map1, source=perm(1:isep, iperm))
ALLOCATE (map2, source=perm(isep + 1:ndims, iperm))

CALL mp_comm%get_rank(mynode)
mynode = mp_comm%mepos
CALL mp_environ_pgrid(comm_nd, pdims, myploc)

#:for dim in range(1, maxdim+1)
Expand Down Expand Up @@ -387,7 +386,7 @@ SUBROUTINE dbt_setup_test_tensor(tensor, mp_comm, enumerate, ${varlist("blk_ind"
INTEGER, DIMENSION(2) :: blk_index_2d, nblks_2d

nblks_alloc = SIZE(blk_ind_1)
CALL mp_comm%get_rank(mynode)
mynode = mp_comm%mepos

IF (.NOT. enumerate) THEN
CPASSERT(randmat_counter .NE. 0)
Expand Down Expand Up @@ -619,7 +618,7 @@ SUBROUTINE dbt_contract_test(alpha, tensor_1, tensor_2, beta, tensor_3, &
LOGICAL :: do_crop_1, do_crop_2

mp_comm = tensor_1%pgrid%mp_comm_2d
CALL mp_comm%get_rank(mynode)
mynode = mp_comm%mepos
io_unit = -1
IF (mynode .EQ. 0) io_unit = unit_nr

Expand Down
21 changes: 14 additions & 7 deletions src/dbt/dbt_types.F
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ SUBROUTINE dbt_pgrid_create_expert(mp_comm, dims, pgrid, map1_2d, map2_2d, tenso

ndims = SIZE(dims)

CALL mp_comm%get_size(nproc)
nproc = mp_comm%num_pe
IF (ANY(dims == 0)) THEN
IF (.NOT. PRESENT(tensor_dims)) THEN
CALL mp_dims_create(nproc, dims)
Expand Down Expand Up @@ -558,7 +558,7 @@ SUBROUTINE dbt_pgrid_create_expert(mp_comm, dims, pgrid, map1_2d, map2_2d, tenso
! **************************************************************************************************
FUNCTION dbt_nd_mp_comm(comm_2d, map1_2d, map2_2d, dims_nd, dims1_nd, dims2_nd, pdims_2d, tdims, &
nsplit, dimsplit)
TYPE(mp_cart_type), INTENT(IN) :: comm_2d
CLASS(mp_comm_type), INTENT(IN) :: comm_2d
INTEGER, DIMENSION(:), INTENT(IN) :: map1_2d, map2_2d
INTEGER, DIMENSION(SIZE(map1_2d) + SIZE(map2_2d)), &
INTENT(IN), OPTIONAL :: dims_nd
Expand All @@ -585,7 +585,15 @@ FUNCTION dbt_nd_mp_comm(comm_2d, map1_2d, map2_2d, dims_nd, dims1_nd, dims2_nd,
IF (PRESENT(pdims_2d)) THEN
dims_2d(:) = pdims_2d
ELSE
CALL comm_2d%get_info_cart(dims_2d)
! This branch allows us to call this routine with a plain mp_comm_type without actually requiring an mp_cart_type
! In a few cases in CP2K, this prevents erroneous calls to mpi_cart_get with a non-cartesian communicator
SELECT TYPE (comm_2d)
CLASS IS (mp_cart_type)
dims_2d = comm_2d%num_pe_cart
CLASS DEFAULT
CALL cp_abort(__LOCATION__, "If the argument pdims_2d is not given, the "// &
"communicator comm_2d must be of class mp_cart_type.")
END SELECT
END IF

IF (.NOT. PRESENT(dims_nd)) THEN
Expand Down Expand Up @@ -670,7 +678,7 @@ SUBROUTINE mp_environ_pgrid(pgrid, dims, task_coor)
INTEGER, DIMENSION(ndims_mapping(pgrid%nd_index_grid)), INTENT(OUT) :: task_coor
INTEGER, DIMENSION(2) :: task_coor_2d

CALL pgrid%mp_comm_2d%get_info_cart(task_coor=task_coor_2d)
task_coor_2d = pgrid%mp_comm_2d%mepos_cart
CALL dbt_get_mapping_info(pgrid%nd_index_grid, dims_nd=dims)
task_coor = get_nd_indices_pgrid(pgrid%nd_index_grid, task_coor_2d)
END SUBROUTINE
Expand Down Expand Up @@ -743,7 +751,7 @@ SUBROUTINE dbt_distribution_new_expert(dist, pgrid, map1_2d, map2_2d, ${varlist(

comm_2d = pgrid_prv%mp_comm_2d

CALL comm_2d%get_info_cart(pdims_2d_check)
pdims_2d_check = comm_2d%num_pe_cart
IF (ANY(pdims_2d_check .NE. pdims_2d)) THEN
CPABORT("inconsistent process grid dimensions")
END IF
Expand Down Expand Up @@ -1183,7 +1191,7 @@ SUBROUTINE dbt_create_matrix(matrix_in, tensor, order, name)

CHARACTER(len=default_string_length) :: name_in
INTEGER, DIMENSION(2) :: order_in
TYPE(mp_cart_type) :: comm_2d
TYPE(mp_comm_type) :: comm_2d
TYPE(dbcsr_distribution_type) :: matrix_dist
TYPE(dbt_distribution_type) :: dist
INTEGER, DIMENSION(:), POINTER :: row_blk_size, col_blk_size
Expand Down Expand Up @@ -1637,7 +1645,6 @@ PURE FUNCTION dbt_max_nblks_local(tensor) RESULT(blk_count)
blk_count_total = PRODUCT(INT(bdims, int_8))

! can not call an MPI routine due to PURE
!CALL tensor%pgrid%mp_comm_2d%get_size(nproc)
nproc = tensor%pgrid%nproc

blk_count = INT(blk_count_total/nproc*max_load_imbalance)
Expand Down
2 changes: 1 addition & 1 deletion src/dbt/dbt_unittest.F
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ PROGRAM dbt_unittest
INTEGER, DIMENSION(:, :), ALLOCATABLE :: bounds, bounds_1, bounds_2

CALL mp_world_init(mp_comm)
CALL mp_comm%get_rank(mynode)
mynode = mp_comm%mepos

! Select active offload device when available.
IF (offload_get_device_count() > 0) THEN
Expand Down
2 changes: 1 addition & 1 deletion src/dbt/tas/dbt_tas_base.F
Original file line number Diff line number Diff line change
Expand Up @@ -625,7 +625,7 @@ SUBROUTINE dbt_tas_convert_to_tas(info, matrix_rect, matrix_dbm)
NULLIFY (col_blk_size, row_blk_size)
CALL timeset(routineN, handle)
CALL info%mp_comm%get_info_cart(pdims)
pdims = info%mp_comm%num_pe_cart
name = dbm_get_name(matrix_dbm)
row_blk_size => dbm_get_row_block_sizes(matrix_dbm)
Expand Down
8 changes: 4 additions & 4 deletions src/dbt/tas/dbt_tas_io.F
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ SUBROUTINE dbt_tas_write_dist(matrix, unit_nr, full_info)

CALL dbt_tas_get_split_info(matrix%dist%info, mp_comm, ngroup, igroup, mp_comm_group, split_rowcol)
CALL dbt_tas_get_info(matrix, name=name)
CALL mp_comm%get_size(nproc)
nproc = mp_comm%num_pe

nblock = dbt_tas_get_num_blocks(matrix)
nelement = dbt_tas_get_nze(matrix)
Expand Down Expand Up @@ -232,9 +232,9 @@ SUBROUTINE dbt_tas_write_split_info(info, unit_nr, name)

CALL dbt_tas_get_split_info(info, mp_comm, nsplit, igroup, mp_comm_group, split_rowcol, pgrid_offset)

CALL mp_comm%get_rank(mynode)
CALL mp_comm%get_info_cart(dims)
CALL mp_comm_group%get_info_cart(groupdims)
mynode = mp_comm%mepos
dims = mp_comm%num_pe_cart
groupdims = mp_comm_group%num_pe_cart

IF (unit_nr_prv > 0) THEN
SELECT CASE (split_rowcol)
Expand Down

0 comments on commit 066d5cd

Please sign in to comment.