From 2dec57d489544588994b8dfb8ec20edc1c98bc67 Mon Sep 17 00:00:00 2001 From: "hasegawa.kento" Date: Thu, 21 Aug 2025 14:33:49 +0900 Subject: [PATCH 1/5] COLL/UCC: renaming _init to avoid naming conflicts Signed-off-by: hasegawa.kento --- ompi/mca/coll/ucc/coll_ucc.h | 74 +++++++++++++++++++ ompi/mca/coll/ucc/coll_ucc_allgather.c | 7 +- ompi/mca/coll/ucc/coll_ucc_allgatherv.c | 7 +- ompi/mca/coll/ucc/coll_ucc_allreduce.c | 7 +- ompi/mca/coll/ucc/coll_ucc_alltoall.c | 7 +- ompi/mca/coll/ucc/coll_ucc_alltoallv.c | 7 +- ompi/mca/coll/ucc/coll_ucc_barrier.c | 7 +- ompi/mca/coll/ucc/coll_ucc_bcast.c | 7 +- ompi/mca/coll/ucc/coll_ucc_gather.c | 7 +- ompi/mca/coll/ucc/coll_ucc_gatherv.c | 7 +- ompi/mca/coll/ucc/coll_ucc_reduce.c | 7 +- ompi/mca/coll/ucc/coll_ucc_reduce_scatter.c | 7 +- .../coll/ucc/coll_ucc_reduce_scatter_block.c | 7 +- ompi/mca/coll/ucc/coll_ucc_scatter.c | 7 +- ompi/mca/coll/ucc/coll_ucc_scatterv.c | 7 +- 15 files changed, 130 insertions(+), 42 deletions(-) diff --git a/ompi/mca/coll/ucc/coll_ucc.h b/ompi/mca/coll/ucc/coll_ucc.h index 510e4796448..7a9a5b4380d 100644 --- a/ompi/mca/coll/ucc/coll_ucc.h +++ b/ompi/mca/coll/ucc/coll_ucc.h @@ -1,6 +1,7 @@ /** Copyright (c) 2021 Mellanox Technologies. All rights reserved. Copyright (c) 2022 NVIDIA Corporation. All rights reserved. + Copyright (c) 2025 Fujitsu Limited. All rights reserved. $COPYRIGHT$ Additional copyrights may follow @@ -305,5 +306,78 @@ int mca_coll_ucc_iscatter(const void *sbuf, size_t scount, ompi_request_t** request, mca_coll_base_module_t *module); +int mca_coll_ucc_allreduce_init(const void *sbuf, void *rbuf, size_t count, + struct ompi_datatype_t *dtype, struct ompi_op_t *op, + struct ompi_communicator_t *comm, struct ompi_info_t *info, + ompi_request_t **request, mca_coll_base_module_t *module); + +int mca_coll_ucc_reduce_init(const void *sbuf, void *rbuf, size_t count, + struct ompi_datatype_t *dtype, struct ompi_op_t *op, int root, + struct ompi_communicator_t *comm, struct ompi_info_t *info, + ompi_request_t **request, mca_coll_base_module_t *module); + +int mca_coll_ucc_barrier_init(struct ompi_communicator_t *comm, struct ompi_info_t *info, + ompi_request_t **request, mca_coll_base_module_t *module); + +int mca_coll_ucc_bcast_init(void *buff, size_t count, struct ompi_datatype_t *datatype, int root, + struct ompi_communicator_t *comm, struct ompi_info_t *info, + ompi_request_t **request, mca_coll_base_module_t *module); + +int mca_coll_ucc_alltoall_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, + void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype, + struct ompi_communicator_t *comm, struct ompi_info_t *info, + ompi_request_t **request, mca_coll_base_module_t *module); + +int mca_coll_ucc_alltoallv_init(const void *sbuf, ompi_count_array_t scounts, + ompi_disp_array_t sdisps, struct ompi_datatype_t *sdtype, + void *rbuf, ompi_count_array_t rcounts, ompi_disp_array_t rdisps, + struct ompi_datatype_t *rdtype, struct ompi_communicator_t *comm, + struct ompi_info_t *info, ompi_request_t **request, + mca_coll_base_module_t *module); + +int mca_coll_ucc_allgather_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, + void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype, + struct ompi_communicator_t *comm, struct ompi_info_t *info, + ompi_request_t **request, mca_coll_base_module_t *module); + +int mca_coll_ucc_allgatherv_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, + void *rbuf, ompi_count_array_t rcounts, ompi_disp_array_t disps, + struct ompi_datatype_t *rdtype, struct ompi_communicator_t *comm, + struct ompi_info_t *info, ompi_request_t **request, + mca_coll_base_module_t *module); + +int mca_coll_ucc_gather_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, + void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype, int root, + struct ompi_communicator_t *comm, struct ompi_info_t *info, + ompi_request_t **request, mca_coll_base_module_t *module); + +int mca_coll_ucc_gatherv_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, + void *rbuf, ompi_count_array_t rcounts, ompi_disp_array_t disps, + struct ompi_datatype_t *rdtype, int root, + struct ompi_communicator_t *comm, struct ompi_info_t *info, + ompi_request_t **request, mca_coll_base_module_t *module); + +int mca_coll_ucc_reduce_scatter_block_init(const void *sbuf, void *rbuf, size_t rcount, + struct ompi_datatype_t *dtype, struct ompi_op_t *op, + struct ompi_communicator_t *comm, + struct ompi_info_t *info, ompi_request_t **request, + mca_coll_base_module_t *module); + +int mca_coll_ucc_reduce_scatter_init(const void *sbuf, void *rbuf, ompi_count_array_t rcounts, + struct ompi_datatype_t *dtype, struct ompi_op_t *op, + struct ompi_communicator_t *comm, struct ompi_info_t *info, + ompi_request_t **request, mca_coll_base_module_t *module); + +int mca_coll_ucc_scatterv_init(const void *sbuf, ompi_count_array_t scounts, + ompi_disp_array_t disps, struct ompi_datatype_t *sdtype, void *rbuf, + size_t rcount, struct ompi_datatype_t *rdtype, int root, + struct ompi_communicator_t *comm, struct ompi_info_t *info, + ompi_request_t **request, mca_coll_base_module_t *module); + +int mca_coll_ucc_scatter_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, + void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype, int root, + struct ompi_communicator_t *comm, struct ompi_info_t *info, + ompi_request_t **request, mca_coll_base_module_t *module); + END_C_DECLS #endif diff --git a/ompi/mca/coll/ucc/coll_ucc_allgather.c b/ompi/mca/coll/ucc/coll_ucc_allgather.c index 2dd3ac68a55..e2dda23660e 100644 --- a/ompi/mca/coll/ucc/coll_ucc_allgather.c +++ b/ompi/mca/coll/ucc/coll_ucc_allgather.c @@ -1,6 +1,7 @@ /** * Copyright (c) 2021 Mellanox Technologies. All rights reserved. + * Copyright (c) 2025 Fujitsu Limited. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -9,7 +10,7 @@ #include "coll_ucc_common.h" -static inline ucc_status_t mca_coll_ucc_allgather_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, +static inline ucc_status_t mca_coll_ucc_allgather_iniz(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, void* rbuf, size_t rcount, struct ompi_datatype_t *rdtype, mca_coll_ucc_module_t *ucc_module, ucc_coll_req_h *req, @@ -74,7 +75,7 @@ int mca_coll_ucc_allgather(const void *sbuf, size_t scount, struct ompi_datatype ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc allgather"); - COLL_UCC_CHECK(mca_coll_ucc_allgather_init(sbuf, scount, sdtype, + COLL_UCC_CHECK(mca_coll_ucc_allgather_iniz(sbuf, scount, sdtype, rbuf, rcount, rdtype, ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); @@ -98,7 +99,7 @@ int mca_coll_ucc_iallgather(const void *sbuf, size_t scount, struct ompi_datatyp UCC_VERBOSE(3, "running ucc iallgather"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_allgather_init(sbuf, scount, sdtype, + COLL_UCC_CHECK(mca_coll_ucc_allgather_iniz(sbuf, scount, sdtype, rbuf, rcount, rdtype, ucc_module, &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); diff --git a/ompi/mca/coll/ucc/coll_ucc_allgatherv.c b/ompi/mca/coll/ucc/coll_ucc_allgatherv.c index 68e786e0c2a..ff92d245a9d 100644 --- a/ompi/mca/coll/ucc/coll_ucc_allgatherv.c +++ b/ompi/mca/coll/ucc/coll_ucc_allgatherv.c @@ -1,6 +1,7 @@ /** * Copyright (c) 2021 Mellanox Technologies. All rights reserved. + * Copyright (c) 2025 Fujitsu Limited. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -9,7 +10,7 @@ #include "coll_ucc_common.h" -static inline ucc_status_t mca_coll_ucc_allgatherv_init(const void *sbuf, size_t scount, +static inline ucc_status_t mca_coll_ucc_allgatherv_iniz(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, void* rbuf, ompi_count_array_t rcounts, ompi_disp_array_t rdisps, struct ompi_datatype_t *rdtype, @@ -75,7 +76,7 @@ int mca_coll_ucc_allgatherv(const void *sbuf, size_t scount, UCC_VERBOSE(3, "running ucc allgatherv"); - COLL_UCC_CHECK(mca_coll_ucc_allgatherv_init(sbuf, scount, sdtype, + COLL_UCC_CHECK(mca_coll_ucc_allgatherv_iniz(sbuf, scount, sdtype, rbuf, rcounts, rdisps, rdtype, ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); @@ -102,7 +103,7 @@ int mca_coll_ucc_iallgatherv(const void *sbuf, size_t scount, UCC_VERBOSE(3, "running ucc iallgatherv"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_allgatherv_init(sbuf, scount, sdtype, + COLL_UCC_CHECK(mca_coll_ucc_allgatherv_iniz(sbuf, scount, sdtype, rbuf, rcounts, rdisps, rdtype, ucc_module, &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); diff --git a/ompi/mca/coll/ucc/coll_ucc_allreduce.c b/ompi/mca/coll/ucc/coll_ucc_allreduce.c index d44b93df07e..af43635fa96 100644 --- a/ompi/mca/coll/ucc/coll_ucc_allreduce.c +++ b/ompi/mca/coll/ucc/coll_ucc_allreduce.c @@ -1,6 +1,7 @@ /** * Copyright (c) 2021 Mellanox Technologies. All rights reserved. + * Copyright (c) 2025 Fujitsu Limited. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -9,7 +10,7 @@ #include "coll_ucc_common.h" -static inline ucc_status_t mca_coll_ucc_allreduce_init(const void *sbuf, void *rbuf, size_t count, +static inline ucc_status_t mca_coll_ucc_allreduce_iniz(const void *sbuf, void *rbuf, size_t count, struct ompi_datatype_t *dtype, struct ompi_op_t *op, mca_coll_ucc_module_t *ucc_module, ucc_coll_req_h *req, @@ -67,7 +68,7 @@ int mca_coll_ucc_allreduce(const void *sbuf, void *rbuf, size_t count, ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc allreduce"); - COLL_UCC_CHECK(mca_coll_ucc_allreduce_init(sbuf, rbuf, count, dtype, op, + COLL_UCC_CHECK(mca_coll_ucc_allreduce_iniz(sbuf, rbuf, count, dtype, op, ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); @@ -90,7 +91,7 @@ int mca_coll_ucc_iallreduce(const void *sbuf, void *rbuf, size_t count, UCC_VERBOSE(3, "running ucc iallreduce"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_allreduce_init(sbuf, rbuf, count, dtype, op, + COLL_UCC_CHECK(mca_coll_ucc_allreduce_iniz(sbuf, rbuf, count, dtype, op, ucc_module, &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; diff --git a/ompi/mca/coll/ucc/coll_ucc_alltoall.c b/ompi/mca/coll/ucc/coll_ucc_alltoall.c index cfb56f47418..52eb2b1ec32 100644 --- a/ompi/mca/coll/ucc/coll_ucc_alltoall.c +++ b/ompi/mca/coll/ucc/coll_ucc_alltoall.c @@ -1,6 +1,7 @@ /** * Copyright (c) 2021 Mellanox Technologies. All rights reserved. + * Copyright (c) 2025 Fujitsu Limited. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -9,7 +10,7 @@ #include "coll_ucc_common.h" -static inline ucc_status_t mca_coll_ucc_alltoall_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, +static inline ucc_status_t mca_coll_ucc_alltoall_iniz(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, void* rbuf, size_t rcount, struct ompi_datatype_t *rdtype, mca_coll_ucc_module_t *ucc_module, ucc_coll_req_h *req, @@ -74,7 +75,7 @@ int mca_coll_ucc_alltoall(const void *sbuf, size_t scount, struct ompi_datatype_ ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc alltoall"); - COLL_UCC_CHECK(mca_coll_ucc_alltoall_init(sbuf, scount, sdtype, + COLL_UCC_CHECK(mca_coll_ucc_alltoall_iniz(sbuf, scount, sdtype, rbuf, rcount, rdtype, ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); @@ -98,7 +99,7 @@ int mca_coll_ucc_ialltoall(const void *sbuf, size_t scount, struct ompi_datatype UCC_VERBOSE(3, "running ucc ialltoall"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_alltoall_init(sbuf, scount, sdtype, + COLL_UCC_CHECK(mca_coll_ucc_alltoall_iniz(sbuf, scount, sdtype, rbuf, rcount, rdtype, ucc_module, &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); diff --git a/ompi/mca/coll/ucc/coll_ucc_alltoallv.c b/ompi/mca/coll/ucc/coll_ucc_alltoallv.c index 1e9e311cf94..8ba51deaa51 100644 --- a/ompi/mca/coll/ucc/coll_ucc_alltoallv.c +++ b/ompi/mca/coll/ucc/coll_ucc_alltoallv.c @@ -1,6 +1,7 @@ /** * Copyright (c) 2021 Mellanox Technologies. All rights reserved. + * Copyright (c) 2025 Fujitsu Limited. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -9,7 +10,7 @@ #include "coll_ucc_common.h" -static inline ucc_status_t mca_coll_ucc_alltoallv_init(const void *sbuf, ompi_count_array_t scounts, +static inline ucc_status_t mca_coll_ucc_alltoallv_iniz(const void *sbuf, ompi_count_array_t scounts, ompi_disp_array_t sdisps, struct ompi_datatype_t *sdtype, void* rbuf, ompi_count_array_t rcounts, ompi_disp_array_t rdisps, struct ompi_datatype_t *rdtype, @@ -77,7 +78,7 @@ int mca_coll_ucc_alltoallv(const void *sbuf, ompi_count_array_t scounts, UCC_VERBOSE(3, "running ucc alltoallv"); - COLL_UCC_CHECK(mca_coll_ucc_alltoallv_init(sbuf, scounts, sdisps, sdtype, + COLL_UCC_CHECK(mca_coll_ucc_alltoallv_iniz(sbuf, scounts, sdisps, sdtype, rbuf, rcounts, rdisps, rdtype, ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); @@ -104,7 +105,7 @@ int mca_coll_ucc_ialltoallv(const void *sbuf, ompi_count_array_t scounts, UCC_VERBOSE(3, "running ucc ialltoallv"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_alltoallv_init(sbuf, scounts, sdisps, sdtype, + COLL_UCC_CHECK(mca_coll_ucc_alltoallv_iniz(sbuf, scounts, sdisps, sdtype, rbuf, rcounts, rdisps, rdtype, ucc_module, &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); diff --git a/ompi/mca/coll/ucc/coll_ucc_barrier.c b/ompi/mca/coll/ucc/coll_ucc_barrier.c index 010ca177fdf..def2d1f2ef3 100644 --- a/ompi/mca/coll/ucc/coll_ucc_barrier.c +++ b/ompi/mca/coll/ucc/coll_ucc_barrier.c @@ -1,5 +1,6 @@ /** * Copyright (c) 2021 Mellanox Technologies. All rights reserved. + * Copyright (c) 2025 Fujitsu Limited. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -8,7 +9,7 @@ #include "coll_ucc_common.h" -static inline ucc_status_t mca_coll_ucc_barrier_init(mca_coll_ucc_module_t *ucc_module, +static inline ucc_status_t mca_coll_ucc_barrier_iniz(mca_coll_ucc_module_t *ucc_module, ucc_coll_req_h *req, mca_coll_ucc_req_t *coll_req) { @@ -30,7 +31,7 @@ int mca_coll_ucc_barrier(struct ompi_communicator_t *comm, ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc barrier"); - COLL_UCC_CHECK(mca_coll_ucc_barrier_init(ucc_module, &req, NULL)); + COLL_UCC_CHECK(mca_coll_ucc_barrier_iniz(ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); return OMPI_SUCCESS; @@ -49,7 +50,7 @@ int mca_coll_ucc_ibarrier(struct ompi_communicator_t *comm, UCC_VERBOSE(3, "running ucc ibarrier"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_barrier_init(ucc_module, &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_barrier_iniz(ucc_module, &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; return OMPI_SUCCESS; diff --git a/ompi/mca/coll/ucc/coll_ucc_bcast.c b/ompi/mca/coll/ucc/coll_ucc_bcast.c index 34296d74b97..0474fce3f40 100644 --- a/ompi/mca/coll/ucc/coll_ucc_bcast.c +++ b/ompi/mca/coll/ucc/coll_ucc_bcast.c @@ -1,5 +1,6 @@ /** * Copyright (c) 2021 Mellanox Technologies. All rights reserved. + * Copyright (c) 2025 Fujitsu Limited. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -8,7 +9,7 @@ #include "coll_ucc_common.h" -static inline ucc_status_t mca_coll_ucc_bcast_init(void *buf, size_t count, struct ompi_datatype_t *dtype, +static inline ucc_status_t mca_coll_ucc_bcast_iniz(void *buf, size_t count, struct ompi_datatype_t *dtype, int root, mca_coll_ucc_module_t *ucc_module, ucc_coll_req_h *req, mca_coll_ucc_req_t *coll_req) @@ -44,7 +45,7 @@ int mca_coll_ucc_bcast(void *buf, size_t count, struct ompi_datatype_t *dtype, mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t*)module; ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc bcast"); - COLL_UCC_CHECK(mca_coll_ucc_bcast_init(buf, count, dtype, root, + COLL_UCC_CHECK(mca_coll_ucc_bcast_iniz(buf, count, dtype, root, ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); @@ -66,7 +67,7 @@ int mca_coll_ucc_ibcast(void *buf, size_t count, struct ompi_datatype_t *dtype, UCC_VERBOSE(3, "running ucc ibcast"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_bcast_init(buf, count, dtype, root, + COLL_UCC_CHECK(mca_coll_ucc_bcast_iniz(buf, count, dtype, root, ucc_module, &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; diff --git a/ompi/mca/coll/ucc/coll_ucc_gather.c b/ompi/mca/coll/ucc/coll_ucc_gather.c index ba91b40b189..bfc5ebd9d73 100644 --- a/ompi/mca/coll/ucc/coll_ucc_gather.c +++ b/ompi/mca/coll/ucc/coll_ucc_gather.c @@ -2,6 +2,7 @@ /** * Copyright (c) 2021 Mellanox Technologies. All rights reserved. * Copyright (c) 2022 NVIDIA Corporation. All rights reserved. + * Copyright (c) 2025 Fujitsu Limited. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -11,7 +12,7 @@ #include "coll_ucc_common.h" static inline -ucc_status_t mca_coll_ucc_gather_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, +ucc_status_t mca_coll_ucc_gather_iniz(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype, int root, mca_coll_ucc_module_t *ucc_module, ucc_coll_req_h *req, @@ -91,7 +92,7 @@ int mca_coll_ucc_gather(const void *sbuf, size_t scount, struct ompi_datatype_t ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc gather"); - COLL_UCC_CHECK(mca_coll_ucc_gather_init(sbuf, scount, sdtype, rbuf, rcount, + COLL_UCC_CHECK(mca_coll_ucc_gather_iniz(sbuf, scount, sdtype, rbuf, rcount, rdtype, root, ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); @@ -116,7 +117,7 @@ int mca_coll_ucc_igather(const void *sbuf, size_t scount, struct ompi_datatype_t UCC_VERBOSE(3, "running ucc igather"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_gather_init(sbuf, scount, sdtype, rbuf, rcount, + COLL_UCC_CHECK(mca_coll_ucc_gather_iniz(sbuf, scount, sdtype, rbuf, rcount, rdtype, root, ucc_module, &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); diff --git a/ompi/mca/coll/ucc/coll_ucc_gatherv.c b/ompi/mca/coll/ucc/coll_ucc_gatherv.c index 5a1da52356c..d6b307be557 100644 --- a/ompi/mca/coll/ucc/coll_ucc_gatherv.c +++ b/ompi/mca/coll/ucc/coll_ucc_gatherv.c @@ -2,6 +2,7 @@ /** * Copyright (c) 2021 Mellanox Technologies. All rights reserved. * Copyright (c) 2022 NVIDIA Corporation. All rights reserved. + * Copyright (c) 2025 Fujitsu Limited. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -10,7 +11,7 @@ #include "coll_ucc_common.h" -static inline ucc_status_t mca_coll_ucc_gatherv_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, +static inline ucc_status_t mca_coll_ucc_gatherv_iniz(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, void *rbuf, ompi_count_array_t rcounts, ompi_disp_array_t disps, struct ompi_datatype_t *rdtype, int root, mca_coll_ucc_module_t *ucc_module, @@ -83,7 +84,7 @@ int mca_coll_ucc_gatherv(const void *sbuf, size_t scount, struct ompi_datatype_t ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc gatherv"); - COLL_UCC_CHECK(mca_coll_ucc_gatherv_init(sbuf, scount, sdtype, rbuf, rcounts, + COLL_UCC_CHECK(mca_coll_ucc_gatherv_iniz(sbuf, scount, sdtype, rbuf, rcounts, disps, rdtype, root, ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); @@ -109,7 +110,7 @@ int mca_coll_ucc_igatherv(const void *sbuf, size_t scount, struct ompi_datatype_ UCC_VERBOSE(3, "running ucc igatherv"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_gatherv_init(sbuf, scount, sdtype, rbuf, rcounts, + COLL_UCC_CHECK(mca_coll_ucc_gatherv_iniz(sbuf, scount, sdtype, rbuf, rcounts, disps, rdtype, root, ucc_module, &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); diff --git a/ompi/mca/coll/ucc/coll_ucc_reduce.c b/ompi/mca/coll/ucc/coll_ucc_reduce.c index 97b5d424ccf..7302b1205dd 100644 --- a/ompi/mca/coll/ucc/coll_ucc_reduce.c +++ b/ompi/mca/coll/ucc/coll_ucc_reduce.c @@ -1,5 +1,6 @@ /** * Copyright (c) 2021 Mellanox Technologies. All rights reserved. + * Copyright (c) 2025 Fujitsu Limited. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -8,7 +9,7 @@ #include "coll_ucc_common.h" -static inline ucc_status_t mca_coll_ucc_reduce_init(const void *sbuf, void *rbuf, size_t count, +static inline ucc_status_t mca_coll_ucc_reduce_iniz(const void *sbuf, void *rbuf, size_t count, struct ompi_datatype_t *dtype, struct ompi_op_t *op, int root, mca_coll_ucc_module_t *ucc_module, @@ -69,7 +70,7 @@ int mca_coll_ucc_reduce(const void *sbuf, void* rbuf, size_t count, ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc reduce"); - COLL_UCC_CHECK(mca_coll_ucc_reduce_init(sbuf, rbuf, count, dtype, op, + COLL_UCC_CHECK(mca_coll_ucc_reduce_iniz(sbuf, rbuf, count, dtype, op, root, ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); @@ -93,7 +94,7 @@ int mca_coll_ucc_ireduce(const void *sbuf, void* rbuf, size_t count, UCC_VERBOSE(3, "running ucc ireduce"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_reduce_init(sbuf, rbuf, count, dtype, op, root, + COLL_UCC_CHECK(mca_coll_ucc_reduce_iniz(sbuf, rbuf, count, dtype, op, root, ucc_module, &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; diff --git a/ompi/mca/coll/ucc/coll_ucc_reduce_scatter.c b/ompi/mca/coll/ucc/coll_ucc_reduce_scatter.c index dabc8f11d03..468e7e76467 100644 --- a/ompi/mca/coll/ucc/coll_ucc_reduce_scatter.c +++ b/ompi/mca/coll/ucc/coll_ucc_reduce_scatter.c @@ -1,6 +1,7 @@ /** * Copyright (c) 2021 Mellanox Technologies. All rights reserved. * Copyright (c) 2022 NVIDIA Corporation. All rights reserved. + * Copyright (c) 2025 Fujitsu Limited. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -10,7 +11,7 @@ #include "coll_ucc_common.h" static inline -ucc_status_t mca_coll_ucc_reduce_scatter_init(const void *sbuf, void *rbuf, ompi_count_array_t rcounts, +ucc_status_t mca_coll_ucc_reduce_scatter_iniz(const void *sbuf, void *rbuf, ompi_count_array_t rcounts, struct ompi_datatype_t *dtype, struct ompi_op_t *op, mca_coll_ucc_module_t *ucc_module, ucc_coll_req_h *req, @@ -83,7 +84,7 @@ int mca_coll_ucc_reduce_scatter(const void *sbuf, void *rbuf, ompi_count_array_t ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc reduce_scatter"); - COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_init(sbuf, rbuf, rcounts, dtype, + COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_iniz(sbuf, rbuf, rcounts, dtype, op, ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); @@ -108,7 +109,7 @@ int mca_coll_ucc_ireduce_scatter(const void *sbuf, void *rbuf, ompi_count_array_ UCC_VERBOSE(3, "running ucc ireduce_scatter"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_init(sbuf, rbuf, rcounts, dtype, + COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_iniz(sbuf, rbuf, rcounts, dtype, op, ucc_module, &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; diff --git a/ompi/mca/coll/ucc/coll_ucc_reduce_scatter_block.c b/ompi/mca/coll/ucc/coll_ucc_reduce_scatter_block.c index 781776e42ca..56a330bdd92 100644 --- a/ompi/mca/coll/ucc/coll_ucc_reduce_scatter_block.c +++ b/ompi/mca/coll/ucc/coll_ucc_reduce_scatter_block.c @@ -1,6 +1,7 @@ /** * Copyright (c) 2021 Mellanox Technologies. All rights reserved. * Copyright (c) 2022 NVIDIA Corporation. All rights reserved. + * Copyright (c) 2025 Fujitsu Limited. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -10,7 +11,7 @@ #include "coll_ucc_common.h" static inline -ucc_status_t mca_coll_ucc_reduce_scatter_block_init(const void *sbuf, void *rbuf, +ucc_status_t mca_coll_ucc_reduce_scatter_block_iniz(const void *sbuf, void *rbuf, size_t rcount, struct ompi_datatype_t *dtype, struct ompi_op_t *op, @@ -74,7 +75,7 @@ int mca_coll_ucc_reduce_scatter_block(const void *sbuf, void *rbuf, size_t rcoun ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc reduce scatter block"); - COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_block_init(sbuf, rbuf, rcount, + COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_block_iniz(sbuf, rbuf, rcount, dtype, op, ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); @@ -100,7 +101,7 @@ int mca_coll_ucc_ireduce_scatter_block(const void *sbuf, void *rbuf, size_t rcou UCC_VERBOSE(3, "running ucc ireduce_scatter_block"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_block_init(sbuf, rbuf, rcount, + COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_block_iniz(sbuf, rbuf, rcount, dtype, op, ucc_module, &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); diff --git a/ompi/mca/coll/ucc/coll_ucc_scatter.c b/ompi/mca/coll/ucc/coll_ucc_scatter.c index 481365f22bd..f0d224284bb 100644 --- a/ompi/mca/coll/ucc/coll_ucc_scatter.c +++ b/ompi/mca/coll/ucc/coll_ucc_scatter.c @@ -1,6 +1,7 @@ /** * Copyright (c) 2021 Mellanox Technologies. All rights reserved. * Copyright (c) 2022 NVIDIA Corporation. All rights reserved. + * Copyright (c) 2025 Fujitsu Limited. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -10,7 +11,7 @@ #include "coll_ucc_common.h" static inline -ucc_status_t mca_coll_ucc_scatter_init(const void *sbuf, size_t scount, +ucc_status_t mca_coll_ucc_scatter_iniz(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype, int root, @@ -93,7 +94,7 @@ int mca_coll_ucc_scatter(const void *sbuf, size_t scount, ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc scatter"); - COLL_UCC_CHECK(mca_coll_ucc_scatter_init(sbuf, scount, sdtype, rbuf, rcount, + COLL_UCC_CHECK(mca_coll_ucc_scatter_iniz(sbuf, scount, sdtype, rbuf, rcount, rdtype, root, ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); @@ -120,7 +121,7 @@ int mca_coll_ucc_iscatter(const void *sbuf, size_t scount, UCC_VERBOSE(3, "running ucc iscatter"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_scatter_init(sbuf, scount, sdtype, rbuf, rcount, + COLL_UCC_CHECK(mca_coll_ucc_scatter_iniz(sbuf, scount, sdtype, rbuf, rcount, rdtype, root, ucc_module, &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); diff --git a/ompi/mca/coll/ucc/coll_ucc_scatterv.c b/ompi/mca/coll/ucc/coll_ucc_scatterv.c index 36d4086a113..236157e358f 100644 --- a/ompi/mca/coll/ucc/coll_ucc_scatterv.c +++ b/ompi/mca/coll/ucc/coll_ucc_scatterv.c @@ -1,6 +1,7 @@ /** * Copyright (c) 2021 Mellanox Technologies. All rights reserved. * Copyright (c) 2022 NVIDIA Corporation. All rights reserved. + * Copyright (c) 2025 Fujitsu Limited. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -10,7 +11,7 @@ #include "coll_ucc_common.h" static inline -ucc_status_t mca_coll_ucc_scatterv_init(const void *sbuf, ompi_count_array_t scounts, +ucc_status_t mca_coll_ucc_scatterv_iniz(const void *sbuf, ompi_count_array_t scounts, ompi_disp_array_t disps, struct ompi_datatype_t *sdtype, void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype, int root, @@ -85,7 +86,7 @@ int mca_coll_ucc_scatterv(const void *sbuf, ompi_count_array_t scounts, ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc scatterv"); - COLL_UCC_CHECK(mca_coll_ucc_scatterv_init(sbuf, scounts, disps, sdtype, + COLL_UCC_CHECK(mca_coll_ucc_scatterv_iniz(sbuf, scounts, disps, sdtype, rbuf, rcount, rdtype, root, ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); @@ -112,7 +113,7 @@ int mca_coll_ucc_iscatterv(const void *sbuf, ompi_count_array_t scounts, UCC_VERBOSE(3, "running ucc iscatterv"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_scatterv_init(sbuf, scounts, disps, sdtype, + COLL_UCC_CHECK(mca_coll_ucc_scatterv_iniz(sbuf, scounts, disps, sdtype, rbuf, rcount, rdtype, root, ucc_module, &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); From 9305110b34e0d40511db95933c189e3efc99eaa1 Mon Sep 17 00:00:00 2001 From: "hasegawa.kento" Date: Thu, 21 Aug 2025 14:33:50 +0900 Subject: [PATCH 2/5] COLL/UCC: add persistent collective calls Signed-off-by: hasegawa.kento --- ompi/mca/coll/ucc/coll_ucc.h | 28 ++++++ ompi/mca/coll/ucc/coll_ucc_allgather.c | 45 ++++++++-- ompi/mca/coll/ucc/coll_ucc_allgatherv.c | 53 ++++++++--- ompi/mca/coll/ucc/coll_ucc_allreduce.c | 39 ++++++-- ompi/mca/coll/ucc/coll_ucc_alltoall.c | 45 ++++++++-- ompi/mca/coll/ucc/coll_ucc_alltoallv.c | 54 ++++++++--- ompi/mca/coll/ucc/coll_ucc_barrier.c | 33 ++++++- ompi/mca/coll/ucc/coll_ucc_bcast.c | 43 +++++++-- ompi/mca/coll/ucc/coll_ucc_common.h | 21 +++++ ompi/mca/coll/ucc/coll_ucc_gather.c | 49 +++++++--- ompi/mca/coll/ucc/coll_ucc_gatherv.c | 52 ++++++++--- ompi/mca/coll/ucc/coll_ucc_module.c | 90 ++++++++++++++++++- ompi/mca/coll/ucc/coll_ucc_reduce.c | 38 ++++++-- ompi/mca/coll/ucc/coll_ucc_reduce_scatter.c | 49 +++++++--- .../coll/ucc/coll_ucc_reduce_scatter_block.c | 54 ++++++++--- ompi/mca/coll/ucc/coll_ucc_scatter.c | 52 ++++++++--- ompi/mca/coll/ucc/coll_ucc_scatterv.c | 54 ++++++++--- 17 files changed, 657 insertions(+), 142 deletions(-) diff --git a/ompi/mca/coll/ucc/coll_ucc.h b/ompi/mca/coll/ucc/coll_ucc.h index 7a9a5b4380d..a1ebf1c2bc5 100644 --- a/ompi/mca/coll/ucc/coll_ucc.h +++ b/ompi/mca/coll/ucc/coll_ucc.h @@ -133,6 +133,34 @@ struct mca_coll_ucc_module_t { mca_coll_base_module_t* previous_scatter_module; mca_coll_base_module_iscatter_fn_t previous_iscatter; mca_coll_base_module_t* previous_iscatter_module; + mca_coll_base_module_allreduce_init_fn_t previous_allreduce_init; + mca_coll_base_module_t *previous_allreduce_init_module; + mca_coll_base_module_reduce_init_fn_t previous_reduce_init; + mca_coll_base_module_t *previous_reduce_init_module; + mca_coll_base_module_barrier_init_fn_t previous_barrier_init; + mca_coll_base_module_t *previous_barrier_init_module; + mca_coll_base_module_bcast_init_fn_t previous_bcast_init; + mca_coll_base_module_t *previous_bcast_init_module; + mca_coll_base_module_alltoall_init_fn_t previous_alltoall_init; + mca_coll_base_module_t *previous_alltoall_init_module; + mca_coll_base_module_alltoallv_init_fn_t previous_alltoallv_init; + mca_coll_base_module_t *previous_alltoallv_init_module; + mca_coll_base_module_allgather_init_fn_t previous_allgather_init; + mca_coll_base_module_t *previous_allgather_init_module; + mca_coll_base_module_allgatherv_init_fn_t previous_allgatherv_init; + mca_coll_base_module_t *previous_allgatherv_init_module; + mca_coll_base_module_gather_init_fn_t previous_gather_init; + mca_coll_base_module_t *previous_gather_init_module; + mca_coll_base_module_gatherv_init_fn_t previous_gatherv_init; + mca_coll_base_module_t *previous_gatherv_init_module; + mca_coll_base_module_reduce_scatter_block_init_fn_t previous_reduce_scatter_block_init; + mca_coll_base_module_t *previous_reduce_scatter_block_init_module; + mca_coll_base_module_reduce_scatter_init_fn_t previous_reduce_scatter_init; + mca_coll_base_module_t *previous_reduce_scatter_init_module; + mca_coll_base_module_scatterv_init_fn_t previous_scatterv_init; + mca_coll_base_module_t *previous_scatterv_init_module; + mca_coll_base_module_scatter_init_fn_t previous_scatter_init; + mca_coll_base_module_t *previous_scatter_init_module; }; typedef struct mca_coll_ucc_module_t mca_coll_ucc_module_t; OBJ_CLASS_DECLARATION(mca_coll_ucc_module_t); diff --git a/ompi/mca/coll/ucc/coll_ucc_allgather.c b/ompi/mca/coll/ucc/coll_ucc_allgather.c index e2dda23660e..f236ba401d2 100644 --- a/ompi/mca/coll/ucc/coll_ucc_allgather.c +++ b/ompi/mca/coll/ucc/coll_ucc_allgather.c @@ -10,11 +10,11 @@ #include "coll_ucc_common.h" -static inline ucc_status_t mca_coll_ucc_allgather_iniz(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, - void* rbuf, size_t rcount, struct ompi_datatype_t *rdtype, - mca_coll_ucc_module_t *ucc_module, - ucc_coll_req_h *req, - mca_coll_ucc_req_t *coll_req) +static inline ucc_status_t +mca_coll_ucc_allgather_iniz(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, + void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype, + bool persistent, mca_coll_ucc_module_t *ucc_module, ucc_coll_req_h *req, + mca_coll_ucc_req_t *coll_req) { ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; bool is_inplace = (MPI_IN_PLACE == sbuf); @@ -60,6 +60,10 @@ static inline ucc_status_t mca_coll_ucc_allgather_iniz(const void *sbuf, size_t coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; } + if (true == persistent) { + coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; + coll.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; + } COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module); return UCC_OK; fallback: @@ -75,8 +79,7 @@ int mca_coll_ucc_allgather(const void *sbuf, size_t scount, struct ompi_datatype ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc allgather"); - COLL_UCC_CHECK(mca_coll_ucc_allgather_iniz(sbuf, scount, sdtype, - rbuf, rcount, rdtype, + COLL_UCC_CHECK(mca_coll_ucc_allgather_iniz(sbuf, scount, sdtype, rbuf, rcount, rdtype, false, ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); @@ -99,8 +102,7 @@ int mca_coll_ucc_iallgather(const void *sbuf, size_t scount, struct ompi_datatyp UCC_VERBOSE(3, "running ucc iallgather"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_allgather_iniz(sbuf, scount, sdtype, - rbuf, rcount, rdtype, + COLL_UCC_CHECK(mca_coll_ucc_allgather_iniz(sbuf, scount, sdtype, rbuf, rcount, rdtype, false, ucc_module, &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; @@ -113,3 +115,28 @@ int mca_coll_ucc_iallgather(const void *sbuf, size_t scount, struct ompi_datatyp return ucc_module->previous_iallgather(sbuf, scount, sdtype, rbuf, rcount, rdtype, comm, request, ucc_module->previous_iallgather_module); } + +int mca_coll_ucc_allgather_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, + void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype, + struct ompi_communicator_t *comm, struct ompi_info_t *info, + ompi_request_t **request, mca_coll_base_module_t *module) +{ + mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t *) module; + ucc_coll_req_h req; + mca_coll_ucc_req_t *coll_req = NULL; + + COLL_UCC_GET_REQ_PC(coll_req); + UCC_VERBOSE(3, "allgather_init init %p", coll_req); + COLL_UCC_CHECK(mca_coll_ucc_allgather_iniz(sbuf, scount, sdtype, rbuf, rcount, rdtype, true, + ucc_module, &req, coll_req)); + *request = &coll_req->super; + return OMPI_SUCCESS; +fallback: + UCC_VERBOSE(3, "running fallback allgather_init"); + if (coll_req) { + mca_coll_ucc_req_free((ompi_request_t **) &coll_req); + } + return ucc_module->previous_allgather_init(sbuf, scount, sdtype, rbuf, rcount, rdtype, comm, + info, request, + ucc_module->previous_allgather_init_module); +} diff --git a/ompi/mca/coll/ucc/coll_ucc_allgatherv.c b/ompi/mca/coll/ucc/coll_ucc_allgatherv.c index ff92d245a9d..b93ded16426 100644 --- a/ompi/mca/coll/ucc/coll_ucc_allgatherv.c +++ b/ompi/mca/coll/ucc/coll_ucc_allgatherv.c @@ -10,13 +10,12 @@ #include "coll_ucc_common.h" -static inline ucc_status_t mca_coll_ucc_allgatherv_iniz(const void *sbuf, size_t scount, - struct ompi_datatype_t *sdtype, - void* rbuf, ompi_count_array_t rcounts, ompi_disp_array_t rdisps, - struct ompi_datatype_t *rdtype, - mca_coll_ucc_module_t *ucc_module, - ucc_coll_req_h *req, - mca_coll_ucc_req_t *coll_req) +static inline ucc_status_t +mca_coll_ucc_allgatherv_iniz(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, + void *rbuf, ompi_count_array_t rcounts, ompi_disp_array_t rdisps, + struct ompi_datatype_t *rdtype, bool persistent, + mca_coll_ucc_module_t *ucc_module, ucc_coll_req_h *req, + mca_coll_ucc_req_t *coll_req) { ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; bool is_inplace = (MPI_IN_PLACE == sbuf); @@ -58,6 +57,10 @@ static inline ucc_status_t mca_coll_ucc_allgatherv_iniz(const void *sbuf, size_t } }; + if (true == persistent) { + coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; + coll.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; + } COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module); return UCC_OK; fallback: @@ -76,9 +79,8 @@ int mca_coll_ucc_allgatherv(const void *sbuf, size_t scount, UCC_VERBOSE(3, "running ucc allgatherv"); - COLL_UCC_CHECK(mca_coll_ucc_allgatherv_iniz(sbuf, scount, sdtype, - rbuf, rcounts, rdisps, rdtype, - ucc_module, &req, NULL)); + COLL_UCC_CHECK(mca_coll_ucc_allgatherv_iniz(sbuf, scount, sdtype, rbuf, rcounts, rdisps, rdtype, + false, ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); return OMPI_SUCCESS; @@ -103,9 +105,8 @@ int mca_coll_ucc_iallgatherv(const void *sbuf, size_t scount, UCC_VERBOSE(3, "running ucc iallgatherv"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_allgatherv_iniz(sbuf, scount, sdtype, - rbuf, rcounts, rdisps, rdtype, - ucc_module, &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_allgatherv_iniz(sbuf, scount, sdtype, rbuf, rcounts, rdisps, rdtype, + false, ucc_module, &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; return OMPI_SUCCESS; @@ -118,3 +119,29 @@ int mca_coll_ucc_iallgatherv(const void *sbuf, size_t scount, rbuf, rcounts, rdisps, rdtype, comm, request, ucc_module->previous_iallgatherv_module); } + +int mca_coll_ucc_allgatherv_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, + void *rbuf, ompi_count_array_t rcounts, ompi_disp_array_t rdisps, + struct ompi_datatype_t *rdtype, struct ompi_communicator_t *comm, + struct ompi_info_t *info, ompi_request_t **request, + mca_coll_base_module_t *module) +{ + mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t *) module; + ucc_coll_req_h req; + mca_coll_ucc_req_t *coll_req = NULL; + + COLL_UCC_GET_REQ_PC(coll_req); + UCC_VERBOSE(3, "allgatherv_init init %p", coll_req); + COLL_UCC_CHECK(mca_coll_ucc_allgatherv_iniz(sbuf, scount, sdtype, rbuf, rcounts, rdisps, rdtype, + true, ucc_module, &req, coll_req)); + *request = &coll_req->super; + return OMPI_SUCCESS; +fallback: + UCC_VERBOSE(3, "running fallback allgatherv_init"); + if (coll_req) { + mca_coll_ucc_req_free((ompi_request_t **) &coll_req); + } + return ucc_module->previous_allgatherv_init(sbuf, scount, sdtype, rbuf, rcounts, rdisps, rdtype, + comm, info, request, + ucc_module->previous_allgatherv_init_module); +} diff --git a/ompi/mca/coll/ucc/coll_ucc_allreduce.c b/ompi/mca/coll/ucc/coll_ucc_allreduce.c index af43635fa96..65c68b60eb6 100644 --- a/ompi/mca/coll/ucc/coll_ucc_allreduce.c +++ b/ompi/mca/coll/ucc/coll_ucc_allreduce.c @@ -12,7 +12,8 @@ static inline ucc_status_t mca_coll_ucc_allreduce_iniz(const void *sbuf, void *rbuf, size_t count, struct ompi_datatype_t *dtype, - struct ompi_op_t *op, mca_coll_ucc_module_t *ucc_module, + struct ompi_op_t *op, bool persistent, + mca_coll_ucc_module_t *ucc_module, ucc_coll_req_h *req, mca_coll_ucc_req_t *coll_req) { @@ -53,6 +54,10 @@ static inline ucc_status_t mca_coll_ucc_allreduce_iniz(const void *sbuf, void *r coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; } + if (true == persistent) { + coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; + coll.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; + } COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module); return UCC_OK; fallback: @@ -68,8 +73,8 @@ int mca_coll_ucc_allreduce(const void *sbuf, void *rbuf, size_t count, ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc allreduce"); - COLL_UCC_CHECK(mca_coll_ucc_allreduce_iniz(sbuf, rbuf, count, dtype, op, - ucc_module, &req, NULL)); + COLL_UCC_CHECK( + mca_coll_ucc_allreduce_iniz(sbuf, rbuf, count, dtype, op, false, ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); return OMPI_SUCCESS; @@ -91,8 +96,8 @@ int mca_coll_ucc_iallreduce(const void *sbuf, void *rbuf, size_t count, UCC_VERBOSE(3, "running ucc iallreduce"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_allreduce_iniz(sbuf, rbuf, count, dtype, op, - ucc_module, &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_allreduce_iniz(sbuf, rbuf, count, dtype, op, false, ucc_module, + &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; return OMPI_SUCCESS; @@ -104,3 +109,27 @@ int mca_coll_ucc_iallreduce(const void *sbuf, void *rbuf, size_t count, return ucc_module->previous_iallreduce(sbuf, rbuf, count, dtype, op, comm, request, ucc_module->previous_iallreduce_module); } + +int mca_coll_ucc_allreduce_init(const void *sbuf, void *rbuf, size_t count, + struct ompi_datatype_t *dtype, struct ompi_op_t *op, + struct ompi_communicator_t *comm, struct ompi_info_t *info, + ompi_request_t **request, mca_coll_base_module_t *module) +{ + mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t *) module; + ucc_coll_req_h req; + mca_coll_ucc_req_t *coll_req = NULL; + + COLL_UCC_GET_REQ_PC(coll_req); + UCC_VERBOSE(3, "allreduce_init init %p", coll_req); + COLL_UCC_CHECK(mca_coll_ucc_allreduce_iniz(sbuf, rbuf, count, dtype, op, true, ucc_module, &req, + coll_req)); + *request = &coll_req->super; + return OMPI_SUCCESS; +fallback: + UCC_VERBOSE(3, "running fallback allreduce_init"); + if (coll_req) { + mca_coll_ucc_req_free((ompi_request_t **) &coll_req); + } + return ucc_module->previous_allreduce_init(sbuf, rbuf, count, dtype, op, comm, info, request, + ucc_module->previous_allreduce_init_module); +} diff --git a/ompi/mca/coll/ucc/coll_ucc_alltoall.c b/ompi/mca/coll/ucc/coll_ucc_alltoall.c index 52eb2b1ec32..8f466b6fcf1 100644 --- a/ompi/mca/coll/ucc/coll_ucc_alltoall.c +++ b/ompi/mca/coll/ucc/coll_ucc_alltoall.c @@ -10,11 +10,11 @@ #include "coll_ucc_common.h" -static inline ucc_status_t mca_coll_ucc_alltoall_iniz(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, - void* rbuf, size_t rcount, struct ompi_datatype_t *rdtype, - mca_coll_ucc_module_t *ucc_module, - ucc_coll_req_h *req, - mca_coll_ucc_req_t *coll_req) +static inline ucc_status_t +mca_coll_ucc_alltoall_iniz(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, + void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype, + bool persistent, mca_coll_ucc_module_t *ucc_module, ucc_coll_req_h *req, + mca_coll_ucc_req_t *coll_req) { ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; bool is_inplace = (MPI_IN_PLACE == sbuf); @@ -60,6 +60,10 @@ static inline ucc_status_t mca_coll_ucc_alltoall_iniz(const void *sbuf, size_t s coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; } + if (true == persistent) { + coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; + coll.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; + } COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module); return UCC_OK; fallback: @@ -75,8 +79,7 @@ int mca_coll_ucc_alltoall(const void *sbuf, size_t scount, struct ompi_datatype_ ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc alltoall"); - COLL_UCC_CHECK(mca_coll_ucc_alltoall_iniz(sbuf, scount, sdtype, - rbuf, rcount, rdtype, + COLL_UCC_CHECK(mca_coll_ucc_alltoall_iniz(sbuf, scount, sdtype, rbuf, rcount, rdtype, false, ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); @@ -99,8 +102,7 @@ int mca_coll_ucc_ialltoall(const void *sbuf, size_t scount, struct ompi_datatype UCC_VERBOSE(3, "running ucc ialltoall"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_alltoall_iniz(sbuf, scount, sdtype, - rbuf, rcount, rdtype, + COLL_UCC_CHECK(mca_coll_ucc_alltoall_iniz(sbuf, scount, sdtype, rbuf, rcount, rdtype, false, ucc_module, &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; @@ -113,3 +115,28 @@ int mca_coll_ucc_ialltoall(const void *sbuf, size_t scount, struct ompi_datatype return ucc_module->previous_ialltoall(sbuf, scount, sdtype, rbuf, rcount, rdtype, comm, request, ucc_module->previous_ialltoall_module); } + +int mca_coll_ucc_alltoall_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, + void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype, + struct ompi_communicator_t *comm, struct ompi_info_t *info, + ompi_request_t **request, mca_coll_base_module_t *module) +{ + mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t *) module; + ucc_coll_req_h req; + mca_coll_ucc_req_t *coll_req = NULL; + + COLL_UCC_GET_REQ_PC(coll_req); + UCC_VERBOSE(3, "alltoall_init init %p", coll_req); + COLL_UCC_CHECK(mca_coll_ucc_alltoall_iniz(sbuf, scount, sdtype, rbuf, rcount, rdtype, true, + ucc_module, &req, coll_req)); + *request = &coll_req->super; + return OMPI_SUCCESS; +fallback: + UCC_VERBOSE(3, "running fallback alltoall_init"); + if (coll_req) { + mca_coll_ucc_req_free((ompi_request_t **) &coll_req); + } + return ucc_module->previous_alltoall_init(sbuf, scount, sdtype, rbuf, rcount, rdtype, comm, + info, request, + ucc_module->previous_alltoall_init_module); +} diff --git a/ompi/mca/coll/ucc/coll_ucc_alltoallv.c b/ompi/mca/coll/ucc/coll_ucc_alltoallv.c index 8ba51deaa51..0f531061676 100644 --- a/ompi/mca/coll/ucc/coll_ucc_alltoallv.c +++ b/ompi/mca/coll/ucc/coll_ucc_alltoallv.c @@ -10,13 +10,12 @@ #include "coll_ucc_common.h" -static inline ucc_status_t mca_coll_ucc_alltoallv_iniz(const void *sbuf, ompi_count_array_t scounts, - ompi_disp_array_t sdisps, struct ompi_datatype_t *sdtype, - void* rbuf, ompi_count_array_t rcounts, ompi_disp_array_t rdisps, - struct ompi_datatype_t *rdtype, - mca_coll_ucc_module_t *ucc_module, - ucc_coll_req_h *req, - mca_coll_ucc_req_t *coll_req) +static inline ucc_status_t +mca_coll_ucc_alltoallv_iniz(const void *sbuf, ompi_count_array_t scounts, ompi_disp_array_t sdisps, + struct ompi_datatype_t *sdtype, void *rbuf, ompi_count_array_t rcounts, + ompi_disp_array_t rdisps, struct ompi_datatype_t *rdtype, + bool persistent, mca_coll_ucc_module_t *ucc_module, ucc_coll_req_h *req, + mca_coll_ucc_req_t *coll_req) { ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; bool is_inplace = (MPI_IN_PLACE == sbuf); @@ -60,6 +59,10 @@ static inline ucc_status_t mca_coll_ucc_alltoallv_iniz(const void *sbuf, ompi_co } }; + if (true == persistent) { + coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; + coll.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; + } COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module); return UCC_OK; fallback: @@ -78,9 +81,8 @@ int mca_coll_ucc_alltoallv(const void *sbuf, ompi_count_array_t scounts, UCC_VERBOSE(3, "running ucc alltoallv"); - COLL_UCC_CHECK(mca_coll_ucc_alltoallv_iniz(sbuf, scounts, sdisps, sdtype, - rbuf, rcounts, rdisps, rdtype, - ucc_module, &req, NULL)); + COLL_UCC_CHECK(mca_coll_ucc_alltoallv_iniz(sbuf, scounts, sdisps, sdtype, rbuf, rcounts, rdisps, + rdtype, false, ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); return OMPI_SUCCESS; @@ -105,9 +107,8 @@ int mca_coll_ucc_ialltoallv(const void *sbuf, ompi_count_array_t scounts, UCC_VERBOSE(3, "running ucc ialltoallv"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_alltoallv_iniz(sbuf, scounts, sdisps, sdtype, - rbuf, rcounts, rdisps, rdtype, - ucc_module, &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_alltoallv_iniz(sbuf, scounts, sdisps, sdtype, rbuf, rcounts, rdisps, + rdtype, false, ucc_module, &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; return OMPI_SUCCESS; @@ -120,3 +121,30 @@ int mca_coll_ucc_ialltoallv(const void *sbuf, ompi_count_array_t scounts, rbuf, rcounts, rdisps, rdtype, comm, request, ucc_module->previous_ialltoallv_module); } + +int mca_coll_ucc_alltoallv_init(const void *sbuf, ompi_count_array_t scounts, + ompi_disp_array_t sdisps, struct ompi_datatype_t *sdtype, + void *rbuf, ompi_count_array_t rcounts, ompi_disp_array_t rdisps, + struct ompi_datatype_t *rdtype, struct ompi_communicator_t *comm, + struct ompi_info_t *info, ompi_request_t **request, + mca_coll_base_module_t *module) +{ + mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t *) module; + ucc_coll_req_h req; + mca_coll_ucc_req_t *coll_req = NULL; + + COLL_UCC_GET_REQ_PC(coll_req); + UCC_VERBOSE(3, "alltoallv_init init %p", coll_req); + COLL_UCC_CHECK(mca_coll_ucc_alltoallv_iniz(sbuf, scounts, sdisps, sdtype, rbuf, rcounts, rdisps, + rdtype, true, ucc_module, &req, coll_req)); + *request = &coll_req->super; + return OMPI_SUCCESS; +fallback: + UCC_VERBOSE(3, "running fallback alltoallv_init"); + if (coll_req) { + mca_coll_ucc_req_free((ompi_request_t **) &coll_req); + } + return ucc_module->previous_alltoallv_init(sbuf, scounts, sdisps, sdtype, rbuf, rcounts, rdisps, + rdtype, comm, info, request, + ucc_module->previous_alltoallv_init_module); +} diff --git a/ompi/mca/coll/ucc/coll_ucc_barrier.c b/ompi/mca/coll/ucc/coll_ucc_barrier.c index def2d1f2ef3..9449332358c 100644 --- a/ompi/mca/coll/ucc/coll_ucc_barrier.c +++ b/ompi/mca/coll/ucc/coll_ucc_barrier.c @@ -9,7 +9,8 @@ #include "coll_ucc_common.h" -static inline ucc_status_t mca_coll_ucc_barrier_iniz(mca_coll_ucc_module_t *ucc_module, +static inline ucc_status_t mca_coll_ucc_barrier_iniz(bool persistent, + mca_coll_ucc_module_t *ucc_module, ucc_coll_req_h *req, mca_coll_ucc_req_t *coll_req) { @@ -18,6 +19,11 @@ static inline ucc_status_t mca_coll_ucc_barrier_iniz(mca_coll_ucc_module_t *ucc_ .flags = 0, .coll_type = UCC_COLL_TYPE_BARRIER }; + + if (true == persistent) { + coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; + coll.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; + } COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module); return UCC_OK; fallback: @@ -31,7 +37,7 @@ int mca_coll_ucc_barrier(struct ompi_communicator_t *comm, ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc barrier"); - COLL_UCC_CHECK(mca_coll_ucc_barrier_iniz(ucc_module, &req, NULL)); + COLL_UCC_CHECK(mca_coll_ucc_barrier_iniz(false, ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); return OMPI_SUCCESS; @@ -50,7 +56,7 @@ int mca_coll_ucc_ibarrier(struct ompi_communicator_t *comm, UCC_VERBOSE(3, "running ucc ibarrier"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_barrier_iniz(ucc_module, &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_barrier_iniz(false, ucc_module, &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; return OMPI_SUCCESS; @@ -62,3 +68,24 @@ int mca_coll_ucc_ibarrier(struct ompi_communicator_t *comm, return ucc_module->previous_ibarrier(comm, request, ucc_module->previous_ibarrier_module); } + +int mca_coll_ucc_barrier_init(struct ompi_communicator_t *comm, struct ompi_info_t *info, + ompi_request_t **request, mca_coll_base_module_t *module) +{ + mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t *) module; + ucc_coll_req_h req; + mca_coll_ucc_req_t *coll_req = NULL; + + COLL_UCC_GET_REQ_PC(coll_req); + UCC_VERBOSE(3, "barrier_init init %p", coll_req); + COLL_UCC_CHECK(mca_coll_ucc_barrier_iniz(true, ucc_module, &req, coll_req)); + *request = &coll_req->super; + return OMPI_SUCCESS; +fallback: + UCC_VERBOSE(3, "running fallback barrier_init"); + if (coll_req) { + mca_coll_ucc_req_free((ompi_request_t **) &coll_req); + } + return ucc_module->previous_barrier_init(comm, info, request, + ucc_module->previous_barrier_init_module); +} diff --git a/ompi/mca/coll/ucc/coll_ucc_bcast.c b/ompi/mca/coll/ucc/coll_ucc_bcast.c index 0474fce3f40..15c05e9acf3 100644 --- a/ompi/mca/coll/ucc/coll_ucc_bcast.c +++ b/ompi/mca/coll/ucc/coll_ucc_bcast.c @@ -9,10 +9,10 @@ #include "coll_ucc_common.h" -static inline ucc_status_t mca_coll_ucc_bcast_iniz(void *buf, size_t count, struct ompi_datatype_t *dtype, - int root, mca_coll_ucc_module_t *ucc_module, - ucc_coll_req_h *req, - mca_coll_ucc_req_t *coll_req) +static inline ucc_status_t +mca_coll_ucc_bcast_iniz(void *buf, size_t count, struct ompi_datatype_t *dtype, int root, + bool persistent, mca_coll_ucc_module_t *ucc_module, ucc_coll_req_h *req, + mca_coll_ucc_req_t *coll_req) { ucc_datatype_t ucc_dt = ompi_dtype_to_ucc_dtype(dtype); if (COLL_UCC_DT_UNSUPPORTED == ucc_dt) { @@ -32,6 +32,11 @@ static inline ucc_status_t mca_coll_ucc_bcast_iniz(void *buf, size_t count, stru .mem_type = UCC_MEMORY_TYPE_UNKNOWN } }; + + if (true == persistent) { + coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; + coll.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; + } COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module); return UCC_OK; fallback: @@ -45,8 +50,7 @@ int mca_coll_ucc_bcast(void *buf, size_t count, struct ompi_datatype_t *dtype, mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t*)module; ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc bcast"); - COLL_UCC_CHECK(mca_coll_ucc_bcast_iniz(buf, count, dtype, root, - ucc_module, &req, NULL)); + COLL_UCC_CHECK(mca_coll_ucc_bcast_iniz(buf, count, dtype, root, false, ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); return OMPI_SUCCESS; @@ -67,8 +71,8 @@ int mca_coll_ucc_ibcast(void *buf, size_t count, struct ompi_datatype_t *dtype, UCC_VERBOSE(3, "running ucc ibcast"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_bcast_iniz(buf, count, dtype, root, - ucc_module, &req, coll_req)); + COLL_UCC_CHECK( + mca_coll_ucc_bcast_iniz(buf, count, dtype, root, false, ucc_module, &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; return OMPI_SUCCESS; @@ -80,3 +84,26 @@ int mca_coll_ucc_ibcast(void *buf, size_t count, struct ompi_datatype_t *dtype, return ucc_module->previous_ibcast(buf, count, dtype, root, comm, request, ucc_module->previous_ibcast_module); } + +int mca_coll_ucc_bcast_init(void *buf, size_t count, struct ompi_datatype_t *dtype, int root, + struct ompi_communicator_t *comm, struct ompi_info_t *info, + ompi_request_t **request, mca_coll_base_module_t *module) +{ + mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t *) module; + ucc_coll_req_h req; + mca_coll_ucc_req_t *coll_req = NULL; + + COLL_UCC_GET_REQ_PC(coll_req); + UCC_VERBOSE(3, "bcast_init init %p", coll_req); + COLL_UCC_CHECK( + mca_coll_ucc_bcast_iniz(buf, count, dtype, root, true, ucc_module, &req, coll_req)); + *request = &coll_req->super; + return OMPI_SUCCESS; +fallback: + UCC_VERBOSE(3, "running fallback bcast_init"); + if (coll_req) { + mca_coll_ucc_req_free((ompi_request_t **) &coll_req); + } + return ucc_module->previous_bcast_init(buf, count, dtype, root, comm, info, request, + ucc_module->previous_bcast_init_module); +} diff --git a/ompi/mca/coll/ucc/coll_ucc_common.h b/ompi/mca/coll/ucc/coll_ucc_common.h index 9d9163aa46d..73eaed79e30 100644 --- a/ompi/mca/coll/ucc/coll_ucc_common.h +++ b/ompi/mca/coll/ucc/coll_ucc_common.h @@ -1,5 +1,6 @@ /** Copyright (c) 2021 Mellanox Technologies. All rights reserved. + Copyright (c) 2025 Fujitsu Limited. All rights reserved. $COPYRIGHT$ Additional copyrights may follow $HEADER$ @@ -42,6 +43,25 @@ _coll_req->super.req_type = OMPI_REQUEST_COLL; \ } while(0) +#define COLL_UCC_GET_REQ_PC(_coll_req) \ + do { \ + opal_free_list_item_t *item; \ + item = opal_free_list_wait(&mca_coll_ucc_component.requests); \ + if (OPAL_UNLIKELY(NULL == item)) { \ + UCC_ERROR("failed to get mca_coll_ucc_req from free_list"); \ + goto fallback; \ + } \ + _coll_req = (mca_coll_ucc_req_t *) item; \ + OMPI_REQUEST_INIT(&_coll_req->super, true); \ + _coll_req->super.req_complete_cb = NULL; \ + _coll_req->super.req_complete_cb_data = NULL; \ + _coll_req->super.req_status.MPI_ERROR = MPI_SUCCESS; \ + _coll_req->super.req_free = mca_coll_ucc_req_free; \ + _coll_req->super.req_start = mca_coll_ucc_req_start; \ + _coll_req->super.req_type = OMPI_REQUEST_COLL; \ + _coll_req->ucc_req = NULL; \ + } while (0) + #define COLL_UCC_REQ_INIT(_coll_req, _req, _coll, _module) do{ \ if (_coll_req) { \ _coll.mask |= UCC_COLL_ARGS_FIELD_CB; \ @@ -76,5 +96,6 @@ static inline ucc_status_t coll_ucc_req_wait(ucc_coll_req_h req) int mca_coll_ucc_req_free(struct ompi_request_t **ompi_req); void mca_coll_ucc_completion(void *data, ucc_status_t status); +int mca_coll_ucc_req_start(size_t count, struct ompi_request_t **requests); #endif diff --git a/ompi/mca/coll/ucc/coll_ucc_gather.c b/ompi/mca/coll/ucc/coll_ucc_gather.c index bfc5ebd9d73..6deecb317ae 100644 --- a/ompi/mca/coll/ucc/coll_ucc_gather.c +++ b/ompi/mca/coll/ucc/coll_ucc_gather.c @@ -11,12 +11,11 @@ #include "coll_ucc_common.h" -static inline -ucc_status_t mca_coll_ucc_gather_iniz(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, - void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype, - int root, mca_coll_ucc_module_t *ucc_module, - ucc_coll_req_h *req, - mca_coll_ucc_req_t *coll_req) +static inline ucc_status_t +mca_coll_ucc_gather_iniz(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, + void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype, int root, + bool persistent, mca_coll_ucc_module_t *ucc_module, ucc_coll_req_h *req, + mca_coll_ucc_req_t *coll_req) { ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; bool is_inplace = (MPI_IN_PLACE == sbuf); @@ -77,6 +76,10 @@ ucc_status_t mca_coll_ucc_gather_iniz(const void *sbuf, size_t scount, struct om coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; } + if (true == persistent) { + coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; + coll.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; + } COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module); return UCC_OK; fallback: @@ -92,9 +95,8 @@ int mca_coll_ucc_gather(const void *sbuf, size_t scount, struct ompi_datatype_t ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc gather"); - COLL_UCC_CHECK(mca_coll_ucc_gather_iniz(sbuf, scount, sdtype, rbuf, rcount, - rdtype, root, ucc_module, - &req, NULL)); + COLL_UCC_CHECK(mca_coll_ucc_gather_iniz(sbuf, scount, sdtype, rbuf, rcount, rdtype, root, false, + ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); return OMPI_SUCCESS; @@ -117,9 +119,8 @@ int mca_coll_ucc_igather(const void *sbuf, size_t scount, struct ompi_datatype_t UCC_VERBOSE(3, "running ucc igather"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_gather_iniz(sbuf, scount, sdtype, rbuf, rcount, - rdtype, root, ucc_module, - &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_gather_iniz(sbuf, scount, sdtype, rbuf, rcount, rdtype, root, false, + ucc_module, &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; return OMPI_SUCCESS; @@ -132,3 +133,27 @@ int mca_coll_ucc_igather(const void *sbuf, size_t scount, struct ompi_datatype_t rdtype, root, comm, request, ucc_module->previous_igather_module); } + +int mca_coll_ucc_gather_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, + void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype, int root, + struct ompi_communicator_t *comm, struct ompi_info_t *info, + ompi_request_t **request, mca_coll_base_module_t *module) +{ + mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t *) module; + ucc_coll_req_h req; + mca_coll_ucc_req_t *coll_req = NULL; + + COLL_UCC_GET_REQ_PC(coll_req); + UCC_VERBOSE(3, "gather_init init %p", coll_req); + COLL_UCC_CHECK(mca_coll_ucc_gather_iniz(sbuf, scount, sdtype, rbuf, rcount, rdtype, root, true, + ucc_module, &req, coll_req)); + *request = &coll_req->super; + return OMPI_SUCCESS; +fallback: + UCC_VERBOSE(3, "running fallback gather_init"); + if (coll_req) { + mca_coll_ucc_req_free((ompi_request_t **) &coll_req); + } + return ucc_module->previous_gather_init(sbuf, scount, sdtype, rbuf, rcount, rdtype, root, comm, + info, request, ucc_module->previous_gather_init_module); +} diff --git a/ompi/mca/coll/ucc/coll_ucc_gatherv.c b/ompi/mca/coll/ucc/coll_ucc_gatherv.c index d6b307be557..5bc0a18279b 100644 --- a/ompi/mca/coll/ucc/coll_ucc_gatherv.c +++ b/ompi/mca/coll/ucc/coll_ucc_gatherv.c @@ -11,12 +11,12 @@ #include "coll_ucc_common.h" -static inline ucc_status_t mca_coll_ucc_gatherv_iniz(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, - void *rbuf, ompi_count_array_t rcounts, ompi_disp_array_t disps, - struct ompi_datatype_t *rdtype, int root, - mca_coll_ucc_module_t *ucc_module, - ucc_coll_req_h *req, - mca_coll_ucc_req_t *coll_req) +static inline ucc_status_t +mca_coll_ucc_gatherv_iniz(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, + void *rbuf, ompi_count_array_t rcounts, ompi_disp_array_t disps, + struct ompi_datatype_t *rdtype, int root, bool persistent, + mca_coll_ucc_module_t *ucc_module, ucc_coll_req_h *req, + mca_coll_ucc_req_t *coll_req) { ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; bool is_inplace = (MPI_IN_PLACE == sbuf); @@ -68,6 +68,10 @@ static inline ucc_status_t mca_coll_ucc_gatherv_iniz(const void *sbuf, size_t sc }, }; + if (true == persistent) { + coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; + coll.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; + } COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module); return UCC_OK; fallback: @@ -84,9 +88,8 @@ int mca_coll_ucc_gatherv(const void *sbuf, size_t scount, struct ompi_datatype_t ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc gatherv"); - COLL_UCC_CHECK(mca_coll_ucc_gatherv_iniz(sbuf, scount, sdtype, rbuf, rcounts, - disps, rdtype, root, ucc_module, - &req, NULL)); + COLL_UCC_CHECK(mca_coll_ucc_gatherv_iniz(sbuf, scount, sdtype, rbuf, rcounts, disps, rdtype, + root, false, ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); return OMPI_SUCCESS; @@ -110,9 +113,8 @@ int mca_coll_ucc_igatherv(const void *sbuf, size_t scount, struct ompi_datatype_ UCC_VERBOSE(3, "running ucc igatherv"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_gatherv_iniz(sbuf, scount, sdtype, rbuf, rcounts, - disps, rdtype, root, ucc_module, - &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_gatherv_iniz(sbuf, scount, sdtype, rbuf, rcounts, disps, rdtype, + root, false, ucc_module, &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; return OMPI_SUCCESS; @@ -125,3 +127,29 @@ int mca_coll_ucc_igatherv(const void *sbuf, size_t scount, struct ompi_datatype_ disps, rdtype, root, comm, request, ucc_module->previous_igatherv_module); } + +int mca_coll_ucc_gatherv_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, + void *rbuf, ompi_count_array_t rcounts, ompi_disp_array_t disps, + struct ompi_datatype_t *rdtype, int root, + struct ompi_communicator_t *comm, struct ompi_info_t *info, + ompi_request_t **request, mca_coll_base_module_t *module) +{ + mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t *) module; + ucc_coll_req_h req; + mca_coll_ucc_req_t *coll_req = NULL; + + COLL_UCC_GET_REQ_PC(coll_req); + UCC_VERBOSE(3, "gatherv_init init %p", coll_req); + COLL_UCC_CHECK(mca_coll_ucc_gatherv_iniz(sbuf, scount, sdtype, rbuf, rcounts, disps, rdtype, + root, true, ucc_module, &req, coll_req)); + *request = &coll_req->super; + return OMPI_SUCCESS; +fallback: + UCC_VERBOSE(3, "running fallback gatherv_init"); + if (coll_req) { + mca_coll_ucc_req_free((ompi_request_t **) &coll_req); + } + return ucc_module->previous_gatherv_init(sbuf, scount, sdtype, rbuf, rcounts, disps, rdtype, + root, comm, info, request, + ucc_module->previous_gatherv_init_module); +} diff --git a/ompi/mca/coll/ucc/coll_ucc_module.c b/ompi/mca/coll/ucc/coll_ucc_module.c index d297274e8c9..a49f1bfaa84 100644 --- a/ompi/mca/coll/ucc/coll_ucc_module.c +++ b/ompi/mca/coll/ucc/coll_ucc_module.c @@ -4,6 +4,7 @@ * All Rights reserved. * Copyright (c) 2022-2025 NVIDIA Corporation. All rights reserved. * Copyright (c) 2024 Triad National Security, LLC. All rights reserved. + * Copyright (c) 2025 Fujitsu Limited. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -87,6 +88,34 @@ static void mca_coll_ucc_module_clear(mca_coll_ucc_module_t *ucc_module) ucc_module->previous_scatter_module = NULL; ucc_module->previous_iscatter = NULL; ucc_module->previous_iscatter_module = NULL; + ucc_module->previous_allreduce_init = NULL; + ucc_module->previous_allreduce_init_module = NULL; + ucc_module->previous_barrier_init = NULL; + ucc_module->previous_barrier_init_module = NULL; + ucc_module->previous_bcast_init = NULL; + ucc_module->previous_bcast_init_module = NULL; + ucc_module->previous_alltoall_init = NULL; + ucc_module->previous_alltoall_init_module = NULL; + ucc_module->previous_alltoallv_init = NULL; + ucc_module->previous_alltoallv_init_module = NULL; + ucc_module->previous_allgather_init = NULL; + ucc_module->previous_allgather_init_module = NULL; + ucc_module->previous_allgatherv_init = NULL; + ucc_module->previous_allgatherv_init_module = NULL; + ucc_module->previous_reduce_init = NULL; + ucc_module->previous_reduce_init_module = NULL; + ucc_module->previous_gather_init = NULL; + ucc_module->previous_gather_init_module = NULL; + ucc_module->previous_gatherv_init = NULL; + ucc_module->previous_gatherv_init_module = NULL; + ucc_module->previous_reduce_scatter_block_init = NULL; + ucc_module->previous_reduce_scatter_block_init_module = NULL; + ucc_module->previous_reduce_scatter_init = NULL; + ucc_module->previous_reduce_scatter_init_module = NULL; + ucc_module->previous_scatterv_init = NULL; + ucc_module->previous_scatterv_init_module = NULL; + ucc_module->previous_scatter_init = NULL; + ucc_module->previous_scatter_init_module = NULL; } static void mca_coll_ucc_module_construct(mca_coll_ucc_module_t *ucc_module) @@ -592,6 +621,19 @@ OBJ_CLASS_INSTANCE(mca_coll_ucc_req_t, ompi_request_t, int mca_coll_ucc_req_free(struct ompi_request_t **ompi_req) { + if (MPI_REQUEST_NULL != ompi_req[0]) { + mca_coll_ucc_req_t *coll_req = (mca_coll_ucc_req_t *) ompi_req[0]; + if (true == coll_req->super.req_persistent) { + UCC_VERBOSE(5, "%s free %p", "_init", coll_req); + if (NULL != coll_req->ucc_req) { + ucc_status_t rc_ucc; + rc_ucc = ucc_collective_finalize(coll_req->ucc_req); + if (UCC_OK != rc_ucc) { + UCC_ERROR("ucc_collective_finalize failed: %s", ucc_status_string(rc_ucc)); + } + } + } + } opal_free_list_return (&mca_coll_ucc_component.requests, (opal_free_list_item_t *)(*ompi_req)); *ompi_req = MPI_REQUEST_NULL; @@ -602,6 +644,52 @@ int mca_coll_ucc_req_free(struct ompi_request_t **ompi_req) void mca_coll_ucc_completion(void *data, ucc_status_t status) { mca_coll_ucc_req_t *coll_req = (mca_coll_ucc_req_t*)data; - ucc_collective_finalize(coll_req->ucc_req); + if (false == coll_req->super.req_persistent) { + ucc_collective_finalize(coll_req->ucc_req); + } else { + UCC_VERBOSE(5, "%s done %p", "_init", coll_req); + assert(!REQUEST_COMPLETE(&coll_req->super)); + } ompi_request_complete(&coll_req->super, true); } + +/* req_start() : ompi_request_start_fn_t */ +int mca_coll_ucc_req_start(size_t count, struct ompi_request_t **requests) +{ + size_t ii; + int rc = OMPI_SUCCESS; + + for (ii = 0; ii < count; ++ii) { + mca_coll_ucc_req_t *coll_req = (mca_coll_ucc_req_t *) requests[ii]; + ucc_status_t rc_ucc; + + if ((NULL == coll_req) || (OMPI_REQUEST_COLL != coll_req->super.req_type)) { + continue; + } + if (true != coll_req->super.req_persistent) { + continue; + } + UCC_VERBOSE(5, "%s post %p", "_init", coll_req); + assert(REQUEST_COMPLETE(&coll_req->super)); + assert(OMPI_REQUEST_INACTIVE == coll_req->super.req_state); + + coll_req->super.req_status.MPI_TAG = MPI_ANY_TAG; + coll_req->super.req_status.MPI_ERROR = OMPI_SUCCESS; + coll_req->super.req_status._cancelled = 0; + coll_req->super.req_complete = REQUEST_PENDING; + coll_req->super.req_state = OMPI_REQUEST_ACTIVE; + + rc_ucc = ucc_collective_post(coll_req->ucc_req); + if (UCC_OK != rc_ucc) { + UCC_ERROR("ucc_collective_post failed: %s", ucc_status_string(rc_ucc)); + coll_req->super.req_complete = REQUEST_COMPLETED; + coll_req->super.req_state = OMPI_REQUEST_INACTIVE; + if (OMPI_SUCCESS == rc) { + rc = OMPI_ERROR; + } + continue; + } + } + + return rc; +} diff --git a/ompi/mca/coll/ucc/coll_ucc_reduce.c b/ompi/mca/coll/ucc/coll_ucc_reduce.c index 7302b1205dd..ed719525334 100644 --- a/ompi/mca/coll/ucc/coll_ucc_reduce.c +++ b/ompi/mca/coll/ucc/coll_ucc_reduce.c @@ -11,7 +11,7 @@ static inline ucc_status_t mca_coll_ucc_reduce_iniz(const void *sbuf, void *rbuf, size_t count, struct ompi_datatype_t *dtype, - struct ompi_op_t *op, int root, + struct ompi_op_t *op, int root, bool persistent, mca_coll_ucc_module_t *ucc_module, ucc_coll_req_h *req, mca_coll_ucc_req_t *coll_req) @@ -54,6 +54,10 @@ static inline ucc_status_t mca_coll_ucc_reduce_iniz(const void *sbuf, void *rbuf coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; } + if (true == persistent) { + coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; + coll.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; + } COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module); return UCC_OK; fallback: @@ -70,8 +74,8 @@ int mca_coll_ucc_reduce(const void *sbuf, void* rbuf, size_t count, ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc reduce"); - COLL_UCC_CHECK(mca_coll_ucc_reduce_iniz(sbuf, rbuf, count, dtype, op, - root, ucc_module, &req, NULL)); + COLL_UCC_CHECK(mca_coll_ucc_reduce_iniz(sbuf, rbuf, count, dtype, op, root, false, ucc_module, + &req, NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); return OMPI_SUCCESS; @@ -94,8 +98,8 @@ int mca_coll_ucc_ireduce(const void *sbuf, void* rbuf, size_t count, UCC_VERBOSE(3, "running ucc ireduce"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_reduce_iniz(sbuf, rbuf, count, dtype, op, root, - ucc_module, &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_reduce_iniz(sbuf, rbuf, count, dtype, op, root, false, ucc_module, + &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; return OMPI_SUCCESS; @@ -107,3 +111,27 @@ int mca_coll_ucc_ireduce(const void *sbuf, void* rbuf, size_t count, return ucc_module->previous_ireduce(sbuf, rbuf, count, dtype, op, root, comm, request, ucc_module->previous_ireduce_module); } + +int mca_coll_ucc_reduce_init(const void *sbuf, void *rbuf, size_t count, + struct ompi_datatype_t *dtype, struct ompi_op_t *op, int root, + struct ompi_communicator_t *comm, struct ompi_info_t *info, + ompi_request_t **request, mca_coll_base_module_t *module) +{ + mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t *) module; + ucc_coll_req_h req; + mca_coll_ucc_req_t *coll_req = NULL; + + COLL_UCC_GET_REQ_PC(coll_req); + UCC_VERBOSE(3, "reduce_init init %p", coll_req); + COLL_UCC_CHECK(mca_coll_ucc_reduce_iniz(sbuf, rbuf, count, dtype, op, root, true, ucc_module, + &req, coll_req)); + *request = &coll_req->super; + return OMPI_SUCCESS; +fallback: + UCC_VERBOSE(3, "running fallback reduce_init"); + if (coll_req) { + mca_coll_ucc_req_free((ompi_request_t **) &coll_req); + } + return ucc_module->previous_reduce_init(sbuf, rbuf, count, dtype, op, root, comm, info, request, + ucc_module->previous_reduce_init_module); +} diff --git a/ompi/mca/coll/ucc/coll_ucc_reduce_scatter.c b/ompi/mca/coll/ucc/coll_ucc_reduce_scatter.c index 468e7e76467..fe1a591b474 100644 --- a/ompi/mca/coll/ucc/coll_ucc_reduce_scatter.c +++ b/ompi/mca/coll/ucc/coll_ucc_reduce_scatter.c @@ -10,12 +10,11 @@ #include "coll_ucc_common.h" -static inline -ucc_status_t mca_coll_ucc_reduce_scatter_iniz(const void *sbuf, void *rbuf, ompi_count_array_t rcounts, - struct ompi_datatype_t *dtype, - struct ompi_op_t *op, mca_coll_ucc_module_t *ucc_module, - ucc_coll_req_h *req, - mca_coll_ucc_req_t *coll_req) +static inline ucc_status_t +mca_coll_ucc_reduce_scatter_iniz(const void *sbuf, void *rbuf, ompi_count_array_t rcounts, + struct ompi_datatype_t *dtype, struct ompi_op_t *op, + bool persistent, mca_coll_ucc_module_t *ucc_module, + ucc_coll_req_h *req, mca_coll_ucc_req_t *coll_req) { ucc_datatype_t ucc_dt; ucc_reduction_op_t ucc_op; @@ -68,6 +67,11 @@ ucc_status_t mca_coll_ucc_reduce_scatter_iniz(const void *sbuf, void *rbuf, ompi }, .op = ucc_op, }; + + if (true == persistent) { + coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; + coll.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; + } COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module); return UCC_OK; fallback: @@ -84,8 +88,8 @@ int mca_coll_ucc_reduce_scatter(const void *sbuf, void *rbuf, ompi_count_array_t ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc reduce_scatter"); - COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_iniz(sbuf, rbuf, rcounts, dtype, - op, ucc_module, &req, NULL)); + COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_iniz(sbuf, rbuf, rcounts, dtype, op, false, + ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); return OMPI_SUCCESS; @@ -109,8 +113,8 @@ int mca_coll_ucc_ireduce_scatter(const void *sbuf, void *rbuf, ompi_count_array_ UCC_VERBOSE(3, "running ucc ireduce_scatter"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_iniz(sbuf, rbuf, rcounts, dtype, - op, ucc_module, &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_iniz(sbuf, rbuf, rcounts, dtype, op, false, + ucc_module, &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; return OMPI_SUCCESS; @@ -123,3 +127,28 @@ int mca_coll_ucc_ireduce_scatter(const void *sbuf, void *rbuf, ompi_count_array_ comm, request, ucc_module->previous_ireduce_scatter_module); } + +int mca_coll_ucc_reduce_scatter_init(const void *sbuf, void *rbuf, ompi_count_array_t rcounts, + struct ompi_datatype_t *dtype, struct ompi_op_t *op, + struct ompi_communicator_t *comm, struct ompi_info_t *info, + ompi_request_t **request, mca_coll_base_module_t *module) +{ + mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t *) module; + ucc_coll_req_h req; + mca_coll_ucc_req_t *coll_req = NULL; + + COLL_UCC_GET_REQ_PC(coll_req); + UCC_VERBOSE(3, "reduce_scatter_init init %p", coll_req); + COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_iniz(sbuf, rbuf, rcounts, dtype, op, true, + ucc_module, &req, coll_req)); + *request = &coll_req->super; + return OMPI_SUCCESS; +fallback: + UCC_VERBOSE(3, "running fallback reduce_scatter_init"); + if (coll_req) { + mca_coll_ucc_req_free((ompi_request_t **) &coll_req); + } + return ucc_module + ->previous_reduce_scatter_init(sbuf, rbuf, rcounts, dtype, op, comm, info, request, + ucc_module->previous_reduce_scatter_init_module); +} diff --git a/ompi/mca/coll/ucc/coll_ucc_reduce_scatter_block.c b/ompi/mca/coll/ucc/coll_ucc_reduce_scatter_block.c index 56a330bdd92..1d88647a24a 100644 --- a/ompi/mca/coll/ucc/coll_ucc_reduce_scatter_block.c +++ b/ompi/mca/coll/ucc/coll_ucc_reduce_scatter_block.c @@ -10,14 +10,11 @@ #include "coll_ucc_common.h" -static inline -ucc_status_t mca_coll_ucc_reduce_scatter_block_iniz(const void *sbuf, void *rbuf, - size_t rcount, - struct ompi_datatype_t *dtype, - struct ompi_op_t *op, - mca_coll_ucc_module_t *ucc_module, - ucc_coll_req_h *req, - mca_coll_ucc_req_t *coll_req) +static inline ucc_status_t +mca_coll_ucc_reduce_scatter_block_iniz(const void *sbuf, void *rbuf, size_t rcount, + struct ompi_datatype_t *dtype, struct ompi_op_t *op, + bool persistent, mca_coll_ucc_module_t *ucc_module, + ucc_coll_req_h *req, mca_coll_ucc_req_t *coll_req) { ucc_datatype_t ucc_dt; ucc_reduction_op_t ucc_op; @@ -59,6 +56,11 @@ ucc_status_t mca_coll_ucc_reduce_scatter_block_iniz(const void *sbuf, void *rbuf }, .op = ucc_op, }; + + if (true == persistent) { + coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; + coll.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; + } COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module); return UCC_OK; fallback: @@ -75,9 +77,8 @@ int mca_coll_ucc_reduce_scatter_block(const void *sbuf, void *rbuf, size_t rcoun ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc reduce scatter block"); - COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_block_iniz(sbuf, rbuf, rcount, - dtype, op, ucc_module, - &req, NULL)); + COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_block_iniz(sbuf, rbuf, rcount, dtype, op, false, + ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); return OMPI_SUCCESS; @@ -101,9 +102,8 @@ int mca_coll_ucc_ireduce_scatter_block(const void *sbuf, void *rbuf, size_t rcou UCC_VERBOSE(3, "running ucc ireduce_scatter_block"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_block_iniz(sbuf, rbuf, rcount, - dtype, op, ucc_module, - &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_block_iniz(sbuf, rbuf, rcount, dtype, op, false, + ucc_module, &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; return OMPI_SUCCESS; @@ -116,3 +116,29 @@ int mca_coll_ucc_ireduce_scatter_block(const void *sbuf, void *rbuf, size_t rcou op, comm, request, ucc_module->previous_ireduce_scatter_block_module); } + +int mca_coll_ucc_reduce_scatter_block_init(const void *sbuf, void *rbuf, size_t rcount, + struct ompi_datatype_t *dtype, struct ompi_op_t *op, + struct ompi_communicator_t *comm, + struct ompi_info_t *info, ompi_request_t **request, + mca_coll_base_module_t *module) +{ + mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t *) module; + ucc_coll_req_h req; + mca_coll_ucc_req_t *coll_req = NULL; + + COLL_UCC_GET_REQ_PC(coll_req); + UCC_VERBOSE(3, "reduce_scatter_block_init init %p", coll_req); + COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_block_iniz(sbuf, rbuf, rcount, dtype, op, true, + ucc_module, &req, coll_req)); + *request = &coll_req->super; + return OMPI_SUCCESS; +fallback: + UCC_VERBOSE(3, "running fallback reduce_scatter_block_init"); + if (coll_req) { + mca_coll_ucc_req_free((ompi_request_t **) &coll_req); + } + return ucc_module + ->previous_reduce_scatter_block_init(sbuf, rbuf, rcount, dtype, op, comm, info, request, + ucc_module->previous_reduce_scatter_block_init_module); +} diff --git a/ompi/mca/coll/ucc/coll_ucc_scatter.c b/ompi/mca/coll/ucc/coll_ucc_scatter.c index f0d224284bb..bf621a3f01a 100644 --- a/ompi/mca/coll/ucc/coll_ucc_scatter.c +++ b/ompi/mca/coll/ucc/coll_ucc_scatter.c @@ -10,14 +10,11 @@ #include "coll_ucc_common.h" -static inline -ucc_status_t mca_coll_ucc_scatter_iniz(const void *sbuf, size_t scount, - struct ompi_datatype_t *sdtype, - void *rbuf, size_t rcount, - struct ompi_datatype_t *rdtype, int root, - mca_coll_ucc_module_t *ucc_module, - ucc_coll_req_h *req, - mca_coll_ucc_req_t *coll_req) +static inline ucc_status_t +mca_coll_ucc_scatter_iniz(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, + void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype, int root, + bool persistent, mca_coll_ucc_module_t *ucc_module, ucc_coll_req_h *req, + mca_coll_ucc_req_t *coll_req) { ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; bool is_inplace = (MPI_IN_PLACE == rbuf); @@ -78,6 +75,10 @@ ucc_status_t mca_coll_ucc_scatter_iniz(const void *sbuf, size_t scount, coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; } + if (true == persistent) { + coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; + coll.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; + } COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module); return UCC_OK; fallback: @@ -94,9 +95,8 @@ int mca_coll_ucc_scatter(const void *sbuf, size_t scount, ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc scatter"); - COLL_UCC_CHECK(mca_coll_ucc_scatter_iniz(sbuf, scount, sdtype, rbuf, rcount, - rdtype, root, ucc_module, &req, - NULL)); + COLL_UCC_CHECK(mca_coll_ucc_scatter_iniz(sbuf, scount, sdtype, rbuf, rcount, rdtype, root, + false, ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); return OMPI_SUCCESS; @@ -121,9 +121,8 @@ int mca_coll_ucc_iscatter(const void *sbuf, size_t scount, UCC_VERBOSE(3, "running ucc iscatter"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_scatter_iniz(sbuf, scount, sdtype, rbuf, rcount, - rdtype, root, ucc_module, &req, - coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_scatter_iniz(sbuf, scount, sdtype, rbuf, rcount, rdtype, root, + false, ucc_module, &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; return OMPI_SUCCESS; @@ -136,3 +135,28 @@ int mca_coll_ucc_iscatter(const void *sbuf, size_t scount, rdtype, root, comm, request, ucc_module->previous_iscatter_module); } + +int mca_coll_ucc_scatter_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, + void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype, int root, + struct ompi_communicator_t *comm, struct ompi_info_t *info, + ompi_request_t **request, mca_coll_base_module_t *module) +{ + mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t *) module; + ucc_coll_req_h req; + mca_coll_ucc_req_t *coll_req = NULL; + + COLL_UCC_GET_REQ_PC(coll_req); + UCC_VERBOSE(3, "scatter_init init %p", coll_req); + COLL_UCC_CHECK(mca_coll_ucc_scatter_iniz(sbuf, scount, sdtype, rbuf, rcount, rdtype, root, true, + ucc_module, &req, coll_req)); + *request = &coll_req->super; + return OMPI_SUCCESS; +fallback: + UCC_VERBOSE(3, "running fallback scatter_init"); + if (coll_req) { + mca_coll_ucc_req_free((ompi_request_t **) &coll_req); + } + return ucc_module->previous_scatter_init(sbuf, scount, sdtype, rbuf, rcount, rdtype, root, comm, + info, request, + ucc_module->previous_scatter_init_module); +} diff --git a/ompi/mca/coll/ucc/coll_ucc_scatterv.c b/ompi/mca/coll/ucc/coll_ucc_scatterv.c index 236157e358f..268a2e4599a 100644 --- a/ompi/mca/coll/ucc/coll_ucc_scatterv.c +++ b/ompi/mca/coll/ucc/coll_ucc_scatterv.c @@ -10,14 +10,12 @@ #include "coll_ucc_common.h" -static inline -ucc_status_t mca_coll_ucc_scatterv_iniz(const void *sbuf, ompi_count_array_t scounts, - ompi_disp_array_t disps, struct ompi_datatype_t *sdtype, - void *rbuf, size_t rcount, - struct ompi_datatype_t *rdtype, int root, - mca_coll_ucc_module_t *ucc_module, - ucc_coll_req_h *req, - mca_coll_ucc_req_t *coll_req) +static inline ucc_status_t +mca_coll_ucc_scatterv_iniz(const void *sbuf, ompi_count_array_t scounts, ompi_disp_array_t disps, + struct ompi_datatype_t *sdtype, void *rbuf, size_t rcount, + struct ompi_datatype_t *rdtype, int root, bool persistent, + mca_coll_ucc_module_t *ucc_module, ucc_coll_req_h *req, + mca_coll_ucc_req_t *coll_req) { ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; bool is_inplace = (MPI_IN_PLACE == rbuf); @@ -69,6 +67,10 @@ ucc_status_t mca_coll_ucc_scatterv_iniz(const void *sbuf, ompi_count_array_t sco }, }; + if (true == persistent) { + coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; + coll.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; + } COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module); return UCC_OK; fallback: @@ -86,9 +88,8 @@ int mca_coll_ucc_scatterv(const void *sbuf, ompi_count_array_t scounts, ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc scatterv"); - COLL_UCC_CHECK(mca_coll_ucc_scatterv_iniz(sbuf, scounts, disps, sdtype, - rbuf, rcount, rdtype, root, - ucc_module, &req, NULL)); + COLL_UCC_CHECK(mca_coll_ucc_scatterv_iniz(sbuf, scounts, disps, sdtype, rbuf, rcount, rdtype, + root, false, ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); return OMPI_SUCCESS; @@ -113,9 +114,8 @@ int mca_coll_ucc_iscatterv(const void *sbuf, ompi_count_array_t scounts, UCC_VERBOSE(3, "running ucc iscatterv"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_scatterv_iniz(sbuf, scounts, disps, sdtype, - rbuf, rcount, rdtype, root, - ucc_module, &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_scatterv_iniz(sbuf, scounts, disps, sdtype, rbuf, rcount, rdtype, + root, false, ucc_module, &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; return OMPI_SUCCESS; @@ -128,3 +128,29 @@ int mca_coll_ucc_iscatterv(const void *sbuf, ompi_count_array_t scounts, rcount, rdtype, root, comm, request, ucc_module->previous_iscatterv_module); } + +int mca_coll_ucc_scatterv_init(const void *sbuf, ompi_count_array_t scounts, + ompi_disp_array_t disps, struct ompi_datatype_t *sdtype, void *rbuf, + size_t rcount, struct ompi_datatype_t *rdtype, int root, + struct ompi_communicator_t *comm, struct ompi_info_t *info, + ompi_request_t **request, mca_coll_base_module_t *module) +{ + mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t *) module; + ucc_coll_req_h req; + mca_coll_ucc_req_t *coll_req = NULL; + + COLL_UCC_GET_REQ_PC(coll_req); + UCC_VERBOSE(3, "scatterv_init init %p", coll_req); + COLL_UCC_CHECK(mca_coll_ucc_scatterv_iniz(sbuf, scounts, disps, sdtype, rbuf, rcount, rdtype, + root, true, ucc_module, &req, coll_req)); + *request = &coll_req->super; + return OMPI_SUCCESS; +fallback: + UCC_VERBOSE(3, "running fallback scatterv_init"); + if (coll_req) { + mca_coll_ucc_req_free((ompi_request_t **) &coll_req); + } + return ucc_module->previous_scatterv_init(sbuf, scounts, disps, sdtype, rbuf, rcount, rdtype, + root, comm, info, request, + ucc_module->previous_scatterv_init_module); +} From 4171fdb0af4865a30b8613b0425b1a5e360dfdf1 Mon Sep 17 00:00:00 2001 From: "hasegawa.kento" Date: Thu, 21 Aug 2025 14:33:50 +0900 Subject: [PATCH 3/5] COLL/UCC: add _init to -mca coll_ucc_cts parser Signed-off-by: hasegawa.kento --- ompi/mca/coll/ucc/coll_ucc.h | 1 + ompi/mca/coll/ucc/coll_ucc_component.c | 75 +++++++++++++++++++++++++- ompi/mca/coll/ucc/coll_ucc_module.c | 27 ++++++++++ 3 files changed, 102 insertions(+), 1 deletion(-) diff --git a/ompi/mca/coll/ucc/coll_ucc.h b/ompi/mca/coll/ucc/coll_ucc.h index a1ebf1c2bc5..a80f7a9196a 100644 --- a/ompi/mca/coll/ucc/coll_ucc.h +++ b/ompi/mca/coll/ucc/coll_ucc.h @@ -62,6 +62,7 @@ struct mca_coll_ucc_component_t { ucc_lib_attr_t ucc_lib_attr; ucc_coll_type_t cts_requested; ucc_coll_type_t nb_cts_requested; + ucc_coll_type_t pc_cts_requested; ucc_context_h ucc_context; opal_free_list_t requests; }; diff --git a/ompi/mca/coll/ucc/coll_ucc_component.c b/ompi/mca/coll/ucc/coll_ucc_component.c index 2f065c8404e..9ceebb3fcd7 100644 --- a/ompi/mca/coll/ucc/coll_ucc_component.c +++ b/ompi/mca/coll/ucc/coll_ucc_component.c @@ -3,6 +3,7 @@ * Copyright (c) 2021 Mellanox Technologies. All rights reserved. * Copyright (c) 2022 NVIDIA Corporation. All rights reserved. * Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2025 Fujitsu Limited. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -143,6 +144,63 @@ static ucc_coll_type_t mca_coll_ucc_str_to_type(const char *str) return UCC_COLL_TYPE_LAST; } +/* is a persistent collective */ +static inline int mca_coll_ucc_init_cts_is_p(const char *cp, char *bp, size_t bz) +{ + size_t len = strlen(cp), len_suffix = sizeof("_init") - 1; + + if ((bz > 0) && (bp != 0)) { + bp[0] = '\0'; + } + /* check if it is a persistent collective */ + if (len > len_suffix) { + size_t blen = len - len_suffix; + const char *cp_suffix = &cp[blen]; + + if (0 == strcmp(cp_suffix, "_init")) { + if ((bz > 0) && (bp != 0)) { + if (blen >= bz) { + return 0 /* XXX internal error */; + } + strncpy(bp, cp, blen); + bp[blen] = '\0'; + } + return 1 /* true */; + } + } + return 0 /* false */; +} + +/* is an alias (special) name */ +static inline int mca_coll_ucc_init_cts_is_a(const char *cp, bool disable, + mca_coll_ucc_component_t *cm) +{ + if (0 == strcmp(cp, "colls_b")) { /* all blocking colls */ + if (disable) { + cm->cts_requested &= ~COLL_UCC_CTS; + } else { + cm->cts_requested |= COLL_UCC_CTS; + } + return 1 /* true */; + } else if ((0 == strcmp(cp, "colls_i")) || (0 == strcmp(cp, "colls_nb"))) { + /* all non-blocking colls */ + if (disable) { + cm->nb_cts_requested &= ~COLL_UCC_CTS; + } else { + cm->nb_cts_requested |= COLL_UCC_CTS; + } + return 1 /* true */; + } else if (0 == strcmp(cp, "colls_p")) { /* all persistent colls */ + if (disable) { + cm->pc_cts_requested &= ~COLL_UCC_CTS; + } else { + cm->pc_cts_requested |= COLL_UCC_CTS; + } + return 1 /* true */; + } + return 0 /* false */; +} + static void mca_coll_ucc_init_default_cts(void) { mca_coll_ucc_component_t *cm = &mca_coll_ucc_component; @@ -157,18 +215,33 @@ static void mca_coll_ucc_init_default_cts(void) n_cts = opal_argv_count(cts); cm->cts_requested = disable ? COLL_UCC_CTS : 0; cm->nb_cts_requested = disable ? COLL_UCC_CTS : 0; + cm->pc_cts_requested = 0; /* XXX PC currently disabled by default */ for (i = 0; i < n_cts; i++) { + char l_str[64]; /* XXX sizeof("reduce_scatter_block") */ + size_t l_stz = sizeof(l_str); + + if (0 < mca_coll_ucc_init_cts_is_a(cts[i], disable, cm)) { + continue; + } if (('i' == cts[i][0]) || ('I' == cts[i][0])) { /* non blocking collective setting */ str = cts[i] + 1; ct = &cm->nb_cts_requested; + } else if (0 < mca_coll_ucc_init_cts_is_p(cts[i], l_str, l_stz)) { + /* persistent collective setting */ + str = l_str; + ct = &cm->pc_cts_requested; } else { str = cts[i]; ct = &cm->cts_requested; } c = mca_coll_ucc_str_to_type(str); if (UCC_COLL_TYPE_LAST == c) { - *ct = COLL_UCC_CTS; + if (&cm->pc_cts_requested != ct) { + *ct = COLL_UCC_CTS; + } else { + *ct = 0; /* XXX PC currently disabled by default */ + } break; } if (disable) { diff --git a/ompi/mca/coll/ucc/coll_ucc_module.c b/ompi/mca/coll/ucc/coll_ucc_module.c index a49f1bfaa84..c3c0b11e3ec 100644 --- a/ompi/mca/coll/ucc/coll_ucc_module.c +++ b/ompi/mca/coll/ucc/coll_ucc_module.c @@ -425,6 +425,12 @@ static inline ucc_ep_map_t get_rank_map(struct ompi_communicator_t *comm) MCA_COLL_INSTALL_API(__comm, i##__api, mca_coll_ucc_i##__api, &__ucc_module->super, "ucc"); \ (__ucc_module)->super.coll_i##__api = mca_coll_ucc_i##__api; \ } \ + if (mca_coll_ucc_component.pc_cts_requested & UCC_COLL_TYPE_##__COLL) \ + { \ + MCA_COLL_SAVE_API(__comm, __api##_init, (__ucc_module)->previous_##__api##_init, (__ucc_module)->previous_##__api##_init_module, "ucc"); \ + MCA_COLL_INSTALL_API(__comm, __api##_init, mca_coll_ucc_##__api##_init, &__ucc_module->super, "ucc"); \ + (__ucc_module)->super.coll_##__api##_init = mca_coll_ucc_##__api##_init; \ + } \ } \ } while (0) @@ -559,11 +565,32 @@ mca_coll_ucc_module_disable(mca_coll_base_module_t *module, UCC_UNINSTALL_COLL_API(comm, ucc_module, reduce); UCC_UNINSTALL_COLL_API(comm, ucc_module, ireduce); UCC_UNINSTALL_COLL_API(comm, ucc_module, gather); + /* UCC_UNINSTALL_COLL_API(comm, ucc_module, igather); */ UCC_UNINSTALL_COLL_API(comm, ucc_module, gatherv); + /* UCC_UNINSTALL_COLL_API(comm, ucc_module, igatherv); */ UCC_UNINSTALL_COLL_API(comm, ucc_module, reduce_scatter_block); + /* UCC_UNINSTALL_COLL_API(comm, ucc_module, ireduce_scatter_block); */ UCC_UNINSTALL_COLL_API(comm, ucc_module, reduce_scatter); + /* UCC_UNINSTALL_COLL_API(comm, ucc_module, ireduce_scatter); */ UCC_UNINSTALL_COLL_API(comm, ucc_module, scatter); + /* UCC_UNINSTALL_COLL_API(comm, ucc_module, iscatter); */ UCC_UNINSTALL_COLL_API(comm, ucc_module, scatterv); + /* UCC_UNINSTALL_COLL_API(comm, ucc_module, iscatterv); */ + + UCC_UNINSTALL_COLL_API(comm, ucc_module, allreduce_init); + UCC_UNINSTALL_COLL_API(comm, ucc_module, barrier_init); + UCC_UNINSTALL_COLL_API(comm, ucc_module, bcast_init); + UCC_UNINSTALL_COLL_API(comm, ucc_module, alltoall_init); + UCC_UNINSTALL_COLL_API(comm, ucc_module, alltoallv_init); + UCC_UNINSTALL_COLL_API(comm, ucc_module, allgather_init); + UCC_UNINSTALL_COLL_API(comm, ucc_module, allgatherv_init); + UCC_UNINSTALL_COLL_API(comm, ucc_module, reduce_init); + UCC_UNINSTALL_COLL_API(comm, ucc_module, gather_init); + UCC_UNINSTALL_COLL_API(comm, ucc_module, gatherv_init); + UCC_UNINSTALL_COLL_API(comm, ucc_module, reduce_scatter_block_init); + UCC_UNINSTALL_COLL_API(comm, ucc_module, reduce_scatter_init); + UCC_UNINSTALL_COLL_API(comm, ucc_module, scatter_init); + UCC_UNINSTALL_COLL_API(comm, ucc_module, scatterv_init); return OMPI_SUCCESS; } From 1b4bf919adf1907df0ad27434e0616081ea1fd48 Mon Sep 17 00:00:00 2001 From: "hasegawa.kento" Date: Mon, 1 Sep 2025 14:39:41 +0900 Subject: [PATCH 4/5] COLL/UCC: fix coding style and error handling Signed-off-by: hasegawa.kento --- ompi/mca/coll/ucc/coll_ucc.h | 56 +++++++------- ompi/mca/coll/ucc/coll_ucc_allgather.c | 26 ++++--- ompi/mca/coll/ucc/coll_ucc_allgatherv.c | 36 ++++----- ompi/mca/coll/ucc/coll_ucc_allreduce.c | 25 +++--- ompi/mca/coll/ucc/coll_ucc_alltoall.c | 28 ++++--- ompi/mca/coll/ucc/coll_ucc_alltoallv.c | 38 +++++----- ompi/mca/coll/ucc/coll_ucc_barrier.c | 15 ++-- ompi/mca/coll/ucc/coll_ucc_bcast.c | 22 +++--- ompi/mca/coll/ucc/coll_ucc_common.h | 2 +- ompi/mca/coll/ucc/coll_ucc_component.c | 37 ++++----- ompi/mca/coll/ucc/coll_ucc_gather.c | 26 ++++--- ompi/mca/coll/ucc/coll_ucc_gatherv.c | 35 ++++----- ompi/mca/coll/ucc/coll_ucc_module.c | 76 ++++++++++--------- ompi/mca/coll/ucc/coll_ucc_reduce.c | 26 +++---- ompi/mca/coll/ucc/coll_ucc_reduce_scatter.c | 31 ++++---- .../coll/ucc/coll_ucc_reduce_scatter_block.c | 28 ++++--- ompi/mca/coll/ucc/coll_ucc_scatter.c | 28 ++++--- ompi/mca/coll/ucc/coll_ucc_scatterv.c | 36 ++++----- 18 files changed, 297 insertions(+), 274 deletions(-) diff --git a/ompi/mca/coll/ucc/coll_ucc.h b/ompi/mca/coll/ucc/coll_ucc.h index a80f7a9196a..da2d1d2e141 100644 --- a/ompi/mca/coll/ucc/coll_ucc.h +++ b/ompi/mca/coll/ucc/coll_ucc.h @@ -62,7 +62,7 @@ struct mca_coll_ucc_component_t { ucc_lib_attr_t ucc_lib_attr; ucc_coll_type_t cts_requested; ucc_coll_type_t nb_cts_requested; - ucc_coll_type_t pc_cts_requested; + ucc_coll_type_t ps_cts_requested; ucc_context_h ucc_context; opal_free_list_t requests; }; @@ -134,34 +134,34 @@ struct mca_coll_ucc_module_t { mca_coll_base_module_t* previous_scatter_module; mca_coll_base_module_iscatter_fn_t previous_iscatter; mca_coll_base_module_t* previous_iscatter_module; - mca_coll_base_module_allreduce_init_fn_t previous_allreduce_init; - mca_coll_base_module_t *previous_allreduce_init_module; - mca_coll_base_module_reduce_init_fn_t previous_reduce_init; - mca_coll_base_module_t *previous_reduce_init_module; - mca_coll_base_module_barrier_init_fn_t previous_barrier_init; - mca_coll_base_module_t *previous_barrier_init_module; - mca_coll_base_module_bcast_init_fn_t previous_bcast_init; - mca_coll_base_module_t *previous_bcast_init_module; - mca_coll_base_module_alltoall_init_fn_t previous_alltoall_init; - mca_coll_base_module_t *previous_alltoall_init_module; - mca_coll_base_module_alltoallv_init_fn_t previous_alltoallv_init; - mca_coll_base_module_t *previous_alltoallv_init_module; - mca_coll_base_module_allgather_init_fn_t previous_allgather_init; - mca_coll_base_module_t *previous_allgather_init_module; - mca_coll_base_module_allgatherv_init_fn_t previous_allgatherv_init; - mca_coll_base_module_t *previous_allgatherv_init_module; - mca_coll_base_module_gather_init_fn_t previous_gather_init; - mca_coll_base_module_t *previous_gather_init_module; - mca_coll_base_module_gatherv_init_fn_t previous_gatherv_init; - mca_coll_base_module_t *previous_gatherv_init_module; + mca_coll_base_module_allreduce_init_fn_t previous_allreduce_init; + mca_coll_base_module_t* previous_allreduce_init_module; + mca_coll_base_module_reduce_init_fn_t previous_reduce_init; + mca_coll_base_module_t* previous_reduce_init_module; + mca_coll_base_module_barrier_init_fn_t previous_barrier_init; + mca_coll_base_module_t* previous_barrier_init_module; + mca_coll_base_module_bcast_init_fn_t previous_bcast_init; + mca_coll_base_module_t* previous_bcast_init_module; + mca_coll_base_module_alltoall_init_fn_t previous_alltoall_init; + mca_coll_base_module_t* previous_alltoall_init_module; + mca_coll_base_module_alltoallv_init_fn_t previous_alltoallv_init; + mca_coll_base_module_t* previous_alltoallv_init_module; + mca_coll_base_module_allgather_init_fn_t previous_allgather_init; + mca_coll_base_module_t* previous_allgather_init_module; + mca_coll_base_module_allgatherv_init_fn_t previous_allgatherv_init; + mca_coll_base_module_t* previous_allgatherv_init_module; + mca_coll_base_module_gather_init_fn_t previous_gather_init; + mca_coll_base_module_t* previous_gather_init_module; + mca_coll_base_module_gatherv_init_fn_t previous_gatherv_init; + mca_coll_base_module_t* previous_gatherv_init_module; mca_coll_base_module_reduce_scatter_block_init_fn_t previous_reduce_scatter_block_init; - mca_coll_base_module_t *previous_reduce_scatter_block_init_module; - mca_coll_base_module_reduce_scatter_init_fn_t previous_reduce_scatter_init; - mca_coll_base_module_t *previous_reduce_scatter_init_module; - mca_coll_base_module_scatterv_init_fn_t previous_scatterv_init; - mca_coll_base_module_t *previous_scatterv_init_module; - mca_coll_base_module_scatter_init_fn_t previous_scatter_init; - mca_coll_base_module_t *previous_scatter_init_module; + mca_coll_base_module_t* previous_reduce_scatter_block_init_module; + mca_coll_base_module_reduce_scatter_init_fn_t previous_reduce_scatter_init; + mca_coll_base_module_t* previous_reduce_scatter_init_module; + mca_coll_base_module_scatterv_init_fn_t previous_scatterv_init; + mca_coll_base_module_t* previous_scatterv_init_module; + mca_coll_base_module_scatter_init_fn_t previous_scatter_init; + mca_coll_base_module_t* previous_scatter_init_module; }; typedef struct mca_coll_ucc_module_t mca_coll_ucc_module_t; OBJ_CLASS_DECLARATION(mca_coll_ucc_module_t); diff --git a/ompi/mca/coll/ucc/coll_ucc_allgather.c b/ompi/mca/coll/ucc/coll_ucc_allgather.c index f236ba401d2..2738468977c 100644 --- a/ompi/mca/coll/ucc/coll_ucc_allgather.c +++ b/ompi/mca/coll/ucc/coll_ucc_allgather.c @@ -11,10 +11,11 @@ #include "coll_ucc_common.h" static inline ucc_status_t -mca_coll_ucc_allgather_iniz(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, - void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype, - bool persistent, mca_coll_ucc_module_t *ucc_module, ucc_coll_req_h *req, - mca_coll_ucc_req_t *coll_req) +mca_coll_ucc_allgather_init_common(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, + void* rbuf, size_t rcount, struct ompi_datatype_t *rdtype, + bool persistent, mca_coll_ucc_module_t *ucc_module, + ucc_coll_req_h *req, + mca_coll_ucc_req_t *coll_req) { ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; bool is_inplace = (MPI_IN_PLACE == sbuf); @@ -79,8 +80,9 @@ int mca_coll_ucc_allgather(const void *sbuf, size_t scount, struct ompi_datatype ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc allgather"); - COLL_UCC_CHECK(mca_coll_ucc_allgather_iniz(sbuf, scount, sdtype, rbuf, rcount, rdtype, false, - ucc_module, &req, NULL)); + COLL_UCC_CHECK(mca_coll_ucc_allgather_init_common(sbuf, scount, sdtype, + rbuf, rcount, rdtype, + false, ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); return OMPI_SUCCESS; @@ -102,8 +104,9 @@ int mca_coll_ucc_iallgather(const void *sbuf, size_t scount, struct ompi_datatyp UCC_VERBOSE(3, "running ucc iallgather"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_allgather_iniz(sbuf, scount, sdtype, rbuf, rcount, rdtype, false, - ucc_module, &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_allgather_init_common(sbuf, scount, sdtype, + rbuf, rcount, rdtype, + false, ucc_module, &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; return OMPI_SUCCESS; @@ -125,10 +128,11 @@ int mca_coll_ucc_allgather_init(const void *sbuf, size_t scount, struct ompi_dat ucc_coll_req_h req; mca_coll_ucc_req_t *coll_req = NULL; - COLL_UCC_GET_REQ_PC(coll_req); + COLL_UCC_GET_REQ_PERSISTENT(coll_req); UCC_VERBOSE(3, "allgather_init init %p", coll_req); - COLL_UCC_CHECK(mca_coll_ucc_allgather_iniz(sbuf, scount, sdtype, rbuf, rcount, rdtype, true, - ucc_module, &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_allgather_init_common(sbuf, scount, sdtype, + rbuf, rcount, rdtype, + true, ucc_module, &req, coll_req)); *request = &coll_req->super; return OMPI_SUCCESS; fallback: diff --git a/ompi/mca/coll/ucc/coll_ucc_allgatherv.c b/ompi/mca/coll/ucc/coll_ucc_allgatherv.c index b93ded16426..a2958496c70 100644 --- a/ompi/mca/coll/ucc/coll_ucc_allgatherv.c +++ b/ompi/mca/coll/ucc/coll_ucc_allgatherv.c @@ -11,11 +11,13 @@ #include "coll_ucc_common.h" static inline ucc_status_t -mca_coll_ucc_allgatherv_iniz(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, - void *rbuf, ompi_count_array_t rcounts, ompi_disp_array_t rdisps, - struct ompi_datatype_t *rdtype, bool persistent, - mca_coll_ucc_module_t *ucc_module, ucc_coll_req_h *req, - mca_coll_ucc_req_t *coll_req) +mca_coll_ucc_allgatherv_init_common(const void *sbuf, size_t scount, + struct ompi_datatype_t *sdtype, + void* rbuf, ompi_count_array_t rcounts, ompi_disp_array_t rdisps, + struct ompi_datatype_t *rdtype, + bool persistent, mca_coll_ucc_module_t *ucc_module, + ucc_coll_req_h *req, + mca_coll_ucc_req_t *coll_req) { ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; bool is_inplace = (MPI_IN_PLACE == sbuf); @@ -36,7 +38,8 @@ mca_coll_ucc_allgatherv_iniz(const void *sbuf, size_t scount, struct ompi_dataty flags = (ompi_count_array_is_64bit(rcounts) ? UCC_COLL_ARGS_FLAG_COUNT_64BIT : 0) | (ompi_disp_array_is_64bit(rdisps) ? UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT : 0) | - (is_inplace ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0); + (is_inplace ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0) | + (persistent ? UCC_COLL_ARGS_FLAG_PERSISTENT : 0); ucc_coll_args_t coll = { .mask = flags ? UCC_COLL_ARGS_FIELD_FLAGS : 0, @@ -57,10 +60,6 @@ mca_coll_ucc_allgatherv_iniz(const void *sbuf, size_t scount, struct ompi_dataty } }; - if (true == persistent) { - coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; - coll.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; - } COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module); return UCC_OK; fallback: @@ -79,8 +78,9 @@ int mca_coll_ucc_allgatherv(const void *sbuf, size_t scount, UCC_VERBOSE(3, "running ucc allgatherv"); - COLL_UCC_CHECK(mca_coll_ucc_allgatherv_iniz(sbuf, scount, sdtype, rbuf, rcounts, rdisps, rdtype, - false, ucc_module, &req, NULL)); + COLL_UCC_CHECK(mca_coll_ucc_allgatherv_init_common(sbuf, scount, sdtype, + rbuf, rcounts, rdisps, rdtype, + false, ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); return OMPI_SUCCESS; @@ -105,8 +105,9 @@ int mca_coll_ucc_iallgatherv(const void *sbuf, size_t scount, UCC_VERBOSE(3, "running ucc iallgatherv"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_allgatherv_iniz(sbuf, scount, sdtype, rbuf, rcounts, rdisps, rdtype, - false, ucc_module, &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_allgatherv_init_common(sbuf, scount, sdtype, + rbuf, rcounts, rdisps, rdtype, + false, ucc_module, &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; return OMPI_SUCCESS; @@ -130,10 +131,11 @@ int mca_coll_ucc_allgatherv_init(const void *sbuf, size_t scount, struct ompi_da ucc_coll_req_h req; mca_coll_ucc_req_t *coll_req = NULL; - COLL_UCC_GET_REQ_PC(coll_req); + COLL_UCC_GET_REQ_PERSISTENT(coll_req); UCC_VERBOSE(3, "allgatherv_init init %p", coll_req); - COLL_UCC_CHECK(mca_coll_ucc_allgatherv_iniz(sbuf, scount, sdtype, rbuf, rcounts, rdisps, rdtype, - true, ucc_module, &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_allgatherv_init_common(sbuf, scount, sdtype, + rbuf, rcounts, rdisps, rdtype, + true, ucc_module, &req, coll_req)); *request = &coll_req->super; return OMPI_SUCCESS; fallback: diff --git a/ompi/mca/coll/ucc/coll_ucc_allreduce.c b/ompi/mca/coll/ucc/coll_ucc_allreduce.c index 65c68b60eb6..bad9ad1f886 100644 --- a/ompi/mca/coll/ucc/coll_ucc_allreduce.c +++ b/ompi/mca/coll/ucc/coll_ucc_allreduce.c @@ -10,12 +10,11 @@ #include "coll_ucc_common.h" -static inline ucc_status_t mca_coll_ucc_allreduce_iniz(const void *sbuf, void *rbuf, size_t count, - struct ompi_datatype_t *dtype, - struct ompi_op_t *op, bool persistent, - mca_coll_ucc_module_t *ucc_module, - ucc_coll_req_h *req, - mca_coll_ucc_req_t *coll_req) +static inline ucc_status_t mca_coll_ucc_allreduce_init_common(const void *sbuf, void *rbuf, size_t count, + struct ompi_datatype_t *dtype, + struct ompi_op_t *op, bool persistent, mca_coll_ucc_module_t *ucc_module, + ucc_coll_req_h *req, + mca_coll_ucc_req_t *coll_req) { ucc_datatype_t ucc_dt; ucc_reduction_op_t ucc_op; @@ -73,8 +72,8 @@ int mca_coll_ucc_allreduce(const void *sbuf, void *rbuf, size_t count, ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc allreduce"); - COLL_UCC_CHECK( - mca_coll_ucc_allreduce_iniz(sbuf, rbuf, count, dtype, op, false, ucc_module, &req, NULL)); + COLL_UCC_CHECK(mca_coll_ucc_allreduce_init_common(sbuf, rbuf, count, dtype, op, + false, ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); return OMPI_SUCCESS; @@ -96,8 +95,8 @@ int mca_coll_ucc_iallreduce(const void *sbuf, void *rbuf, size_t count, UCC_VERBOSE(3, "running ucc iallreduce"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_allreduce_iniz(sbuf, rbuf, count, dtype, op, false, ucc_module, - &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_allreduce_init_common(sbuf, rbuf, count, dtype, op, + false, ucc_module, &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; return OMPI_SUCCESS; @@ -119,10 +118,10 @@ int mca_coll_ucc_allreduce_init(const void *sbuf, void *rbuf, size_t count, ucc_coll_req_h req; mca_coll_ucc_req_t *coll_req = NULL; - COLL_UCC_GET_REQ_PC(coll_req); + COLL_UCC_GET_REQ_PERSISTENT(coll_req); UCC_VERBOSE(3, "allreduce_init init %p", coll_req); - COLL_UCC_CHECK(mca_coll_ucc_allreduce_iniz(sbuf, rbuf, count, dtype, op, true, ucc_module, &req, - coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_allreduce_init_common(sbuf, rbuf, count, dtype, op, + true, ucc_module, &req, coll_req)); *request = &coll_req->super; return OMPI_SUCCESS; fallback: diff --git a/ompi/mca/coll/ucc/coll_ucc_alltoall.c b/ompi/mca/coll/ucc/coll_ucc_alltoall.c index 8f466b6fcf1..da56ce24502 100644 --- a/ompi/mca/coll/ucc/coll_ucc_alltoall.c +++ b/ompi/mca/coll/ucc/coll_ucc_alltoall.c @@ -11,10 +11,11 @@ #include "coll_ucc_common.h" static inline ucc_status_t -mca_coll_ucc_alltoall_iniz(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, - void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype, - bool persistent, mca_coll_ucc_module_t *ucc_module, ucc_coll_req_h *req, - mca_coll_ucc_req_t *coll_req) +mca_coll_ucc_alltoall_init_common(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, + void* rbuf, size_t rcount, struct ompi_datatype_t *rdtype, + bool persistent, mca_coll_ucc_module_t *ucc_module, + ucc_coll_req_h *req, + mca_coll_ucc_req_t *coll_req) { ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; bool is_inplace = (MPI_IN_PLACE == sbuf); @@ -79,15 +80,16 @@ int mca_coll_ucc_alltoall(const void *sbuf, size_t scount, struct ompi_datatype_ ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc alltoall"); - COLL_UCC_CHECK(mca_coll_ucc_alltoall_iniz(sbuf, scount, sdtype, rbuf, rcount, rdtype, false, - ucc_module, &req, NULL)); + COLL_UCC_CHECK(mca_coll_ucc_alltoall_init_common(sbuf, scount, sdtype, + rbuf, rcount, rdtype, + false, ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); return OMPI_SUCCESS; fallback: UCC_VERBOSE(3, "running fallback alltoall"); return ucc_module->previous_alltoall(sbuf, scount, sdtype, rbuf, rcount, rdtype, - comm, ucc_module->previous_alltoall_module); + comm, ucc_module->previous_alltoall_module); } int mca_coll_ucc_ialltoall(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, @@ -102,8 +104,9 @@ int mca_coll_ucc_ialltoall(const void *sbuf, size_t scount, struct ompi_datatype UCC_VERBOSE(3, "running ucc ialltoall"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_alltoall_iniz(sbuf, scount, sdtype, rbuf, rcount, rdtype, false, - ucc_module, &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_alltoall_init_common(sbuf, scount, sdtype, + rbuf, rcount, rdtype, + false, ucc_module, &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; return OMPI_SUCCESS; @@ -125,10 +128,11 @@ int mca_coll_ucc_alltoall_init(const void *sbuf, size_t scount, struct ompi_data ucc_coll_req_h req; mca_coll_ucc_req_t *coll_req = NULL; - COLL_UCC_GET_REQ_PC(coll_req); + COLL_UCC_GET_REQ_PERSISTENT(coll_req); UCC_VERBOSE(3, "alltoall_init init %p", coll_req); - COLL_UCC_CHECK(mca_coll_ucc_alltoall_iniz(sbuf, scount, sdtype, rbuf, rcount, rdtype, true, - ucc_module, &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_alltoall_init_common(sbuf, scount, sdtype, + rbuf, rcount, rdtype, + true, ucc_module, &req, coll_req)); *request = &coll_req->super; return OMPI_SUCCESS; fallback: diff --git a/ompi/mca/coll/ucc/coll_ucc_alltoallv.c b/ompi/mca/coll/ucc/coll_ucc_alltoallv.c index 0f531061676..ce9b7e03fee 100644 --- a/ompi/mca/coll/ucc/coll_ucc_alltoallv.c +++ b/ompi/mca/coll/ucc/coll_ucc_alltoallv.c @@ -11,11 +11,13 @@ #include "coll_ucc_common.h" static inline ucc_status_t -mca_coll_ucc_alltoallv_iniz(const void *sbuf, ompi_count_array_t scounts, ompi_disp_array_t sdisps, - struct ompi_datatype_t *sdtype, void *rbuf, ompi_count_array_t rcounts, - ompi_disp_array_t rdisps, struct ompi_datatype_t *rdtype, - bool persistent, mca_coll_ucc_module_t *ucc_module, ucc_coll_req_h *req, - mca_coll_ucc_req_t *coll_req) +mca_coll_ucc_alltoallv_init_common(const void *sbuf, ompi_count_array_t scounts, + ompi_disp_array_t sdisps, struct ompi_datatype_t *sdtype, + void* rbuf, ompi_count_array_t rcounts, ompi_disp_array_t rdisps, + struct ompi_datatype_t *rdtype, + bool persistent, mca_coll_ucc_module_t *ucc_module, + ucc_coll_req_h *req, + mca_coll_ucc_req_t *coll_req) { ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; bool is_inplace = (MPI_IN_PLACE == sbuf); @@ -37,7 +39,8 @@ mca_coll_ucc_alltoallv_iniz(const void *sbuf, ompi_count_array_t scounts, ompi_d /* Assumes that send counts/displs and recv counts/displs are both 32-bit or both 64-bit */ flags = (ompi_count_array_is_64bit(scounts) ? UCC_COLL_ARGS_FLAG_COUNT_64BIT : 0) | (ompi_disp_array_is_64bit(sdisps) ? UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT : 0) | - (is_inplace ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0); + (is_inplace ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0) | + (persistent ? UCC_COLL_ARGS_FLAG_PERSISTENT : 0); ucc_coll_args_t coll = { .mask = flags ? UCC_COLL_ARGS_FIELD_FLAGS : 0, @@ -59,10 +62,6 @@ mca_coll_ucc_alltoallv_iniz(const void *sbuf, ompi_count_array_t scounts, ompi_d } }; - if (true == persistent) { - coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; - coll.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; - } COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module); return UCC_OK; fallback: @@ -81,8 +80,9 @@ int mca_coll_ucc_alltoallv(const void *sbuf, ompi_count_array_t scounts, UCC_VERBOSE(3, "running ucc alltoallv"); - COLL_UCC_CHECK(mca_coll_ucc_alltoallv_iniz(sbuf, scounts, sdisps, sdtype, rbuf, rcounts, rdisps, - rdtype, false, ucc_module, &req, NULL)); + COLL_UCC_CHECK(mca_coll_ucc_alltoallv_init_common(sbuf, scounts, sdisps, sdtype, + rbuf, rcounts, rdisps, rdtype, + false, ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); return OMPI_SUCCESS; @@ -107,8 +107,9 @@ int mca_coll_ucc_ialltoallv(const void *sbuf, ompi_count_array_t scounts, UCC_VERBOSE(3, "running ucc ialltoallv"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_alltoallv_iniz(sbuf, scounts, sdisps, sdtype, rbuf, rcounts, rdisps, - rdtype, false, ucc_module, &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_alltoallv_init_common(sbuf, scounts, sdisps, sdtype, + rbuf, rcounts, rdisps, rdtype, + false, ucc_module, &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; return OMPI_SUCCESS; @@ -118,7 +119,7 @@ int mca_coll_ucc_ialltoallv(const void *sbuf, ompi_count_array_t scounts, mca_coll_ucc_req_free((ompi_request_t **)&coll_req); } return ucc_module->previous_ialltoallv(sbuf, scounts, sdisps, sdtype, - rbuf, rcounts, rdisps, rdtype, + rbuf, rcounts, rdisps, rdtype, comm, request, ucc_module->previous_ialltoallv_module); } @@ -133,10 +134,11 @@ int mca_coll_ucc_alltoallv_init(const void *sbuf, ompi_count_array_t scounts, ucc_coll_req_h req; mca_coll_ucc_req_t *coll_req = NULL; - COLL_UCC_GET_REQ_PC(coll_req); + COLL_UCC_GET_REQ_PERSISTENT(coll_req); UCC_VERBOSE(3, "alltoallv_init init %p", coll_req); - COLL_UCC_CHECK(mca_coll_ucc_alltoallv_iniz(sbuf, scounts, sdisps, sdtype, rbuf, rcounts, rdisps, - rdtype, true, ucc_module, &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_alltoallv_init_common(sbuf, scounts, sdisps, sdtype, + rbuf, rcounts, rdisps, rdtype, + true, ucc_module, &req, coll_req)); *request = &coll_req->super; return OMPI_SUCCESS; fallback: diff --git a/ompi/mca/coll/ucc/coll_ucc_barrier.c b/ompi/mca/coll/ucc/coll_ucc_barrier.c index 9449332358c..7deab58b4b0 100644 --- a/ompi/mca/coll/ucc/coll_ucc_barrier.c +++ b/ompi/mca/coll/ucc/coll_ucc_barrier.c @@ -9,10 +9,9 @@ #include "coll_ucc_common.h" -static inline ucc_status_t mca_coll_ucc_barrier_iniz(bool persistent, - mca_coll_ucc_module_t *ucc_module, - ucc_coll_req_h *req, - mca_coll_ucc_req_t *coll_req) +static inline ucc_status_t mca_coll_ucc_barrier_init_common(bool persistent, mca_coll_ucc_module_t *ucc_module, + ucc_coll_req_h *req, + mca_coll_ucc_req_t *coll_req) { ucc_coll_args_t coll = { .mask = 0, @@ -37,7 +36,7 @@ int mca_coll_ucc_barrier(struct ompi_communicator_t *comm, ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc barrier"); - COLL_UCC_CHECK(mca_coll_ucc_barrier_iniz(false, ucc_module, &req, NULL)); + COLL_UCC_CHECK(mca_coll_ucc_barrier_init_common(false, ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); return OMPI_SUCCESS; @@ -56,7 +55,7 @@ int mca_coll_ucc_ibarrier(struct ompi_communicator_t *comm, UCC_VERBOSE(3, "running ucc ibarrier"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_barrier_iniz(false, ucc_module, &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_barrier_init_common(false, ucc_module, &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; return OMPI_SUCCESS; @@ -76,9 +75,9 @@ int mca_coll_ucc_barrier_init(struct ompi_communicator_t *comm, struct ompi_info ucc_coll_req_h req; mca_coll_ucc_req_t *coll_req = NULL; - COLL_UCC_GET_REQ_PC(coll_req); + COLL_UCC_GET_REQ_PERSISTENT(coll_req); UCC_VERBOSE(3, "barrier_init init %p", coll_req); - COLL_UCC_CHECK(mca_coll_ucc_barrier_iniz(true, ucc_module, &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_barrier_init_common(true, ucc_module, &req, coll_req)); *request = &coll_req->super; return OMPI_SUCCESS; fallback: diff --git a/ompi/mca/coll/ucc/coll_ucc_bcast.c b/ompi/mca/coll/ucc/coll_ucc_bcast.c index 15c05e9acf3..f32771fd965 100644 --- a/ompi/mca/coll/ucc/coll_ucc_bcast.c +++ b/ompi/mca/coll/ucc/coll_ucc_bcast.c @@ -10,9 +10,10 @@ #include "coll_ucc_common.h" static inline ucc_status_t -mca_coll_ucc_bcast_iniz(void *buf, size_t count, struct ompi_datatype_t *dtype, int root, - bool persistent, mca_coll_ucc_module_t *ucc_module, ucc_coll_req_h *req, - mca_coll_ucc_req_t *coll_req) +mca_coll_ucc_bcast_init_common(void *buf, size_t count, struct ompi_datatype_t *dtype, + int root, bool persistent, mca_coll_ucc_module_t *ucc_module, + ucc_coll_req_h *req, + mca_coll_ucc_req_t *coll_req) { ucc_datatype_t ucc_dt = ompi_dtype_to_ucc_dtype(dtype); if (COLL_UCC_DT_UNSUPPORTED == ucc_dt) { @@ -50,14 +51,15 @@ int mca_coll_ucc_bcast(void *buf, size_t count, struct ompi_datatype_t *dtype, mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t*)module; ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc bcast"); - COLL_UCC_CHECK(mca_coll_ucc_bcast_iniz(buf, count, dtype, root, false, ucc_module, &req, NULL)); + COLL_UCC_CHECK(mca_coll_ucc_bcast_init_common(buf, count, dtype, root, + false, ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); return OMPI_SUCCESS; fallback: UCC_VERBOSE(3, "running fallback bcast"); return ucc_module->previous_bcast(buf, count, dtype, root, - comm, ucc_module->previous_bcast_module); + comm, ucc_module->previous_bcast_module); } int mca_coll_ucc_ibcast(void *buf, size_t count, struct ompi_datatype_t *dtype, @@ -71,8 +73,8 @@ int mca_coll_ucc_ibcast(void *buf, size_t count, struct ompi_datatype_t *dtype, UCC_VERBOSE(3, "running ucc ibcast"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK( - mca_coll_ucc_bcast_iniz(buf, count, dtype, root, false, ucc_module, &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_bcast_init_common(buf, count, dtype, root, + false, ucc_module, &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; return OMPI_SUCCESS; @@ -93,10 +95,10 @@ int mca_coll_ucc_bcast_init(void *buf, size_t count, struct ompi_datatype_t *dty ucc_coll_req_h req; mca_coll_ucc_req_t *coll_req = NULL; - COLL_UCC_GET_REQ_PC(coll_req); + COLL_UCC_GET_REQ_PERSISTENT(coll_req); UCC_VERBOSE(3, "bcast_init init %p", coll_req); - COLL_UCC_CHECK( - mca_coll_ucc_bcast_iniz(buf, count, dtype, root, true, ucc_module, &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_bcast_init_common(buf, count, dtype, root, + true, ucc_module, &req, coll_req)); *request = &coll_req->super; return OMPI_SUCCESS; fallback: diff --git a/ompi/mca/coll/ucc/coll_ucc_common.h b/ompi/mca/coll/ucc/coll_ucc_common.h index 73eaed79e30..09bd9359a1e 100644 --- a/ompi/mca/coll/ucc/coll_ucc_common.h +++ b/ompi/mca/coll/ucc/coll_ucc_common.h @@ -43,7 +43,7 @@ _coll_req->super.req_type = OMPI_REQUEST_COLL; \ } while(0) -#define COLL_UCC_GET_REQ_PC(_coll_req) \ +#define COLL_UCC_GET_REQ_PERSISTENT(_coll_req) \ do { \ opal_free_list_item_t *item; \ item = opal_free_list_wait(&mca_coll_ucc_component.requests); \ diff --git a/ompi/mca/coll/ucc/coll_ucc_component.c b/ompi/mca/coll/ucc/coll_ucc_component.c index 9ceebb3fcd7..4fde1e0a999 100644 --- a/ompi/mca/coll/ucc/coll_ucc_component.c +++ b/ompi/mca/coll/ucc/coll_ucc_component.c @@ -145,25 +145,20 @@ static ucc_coll_type_t mca_coll_ucc_str_to_type(const char *str) } /* is a persistent collective */ -static inline int mca_coll_ucc_init_cts_is_p(const char *cp, char *bp, size_t bz) +static inline int mca_coll_ucc_init_cts_is_persistent(const char *cp, char *bp, size_t bz) { size_t len = strlen(cp), len_suffix = sizeof("_init") - 1; - if ((bz > 0) && (bp != 0)) { - bp[0] = '\0'; - } + assert((bz > 0) && (bp != 0)); /* check if it is a persistent collective */ if (len > len_suffix) { size_t blen = len - len_suffix; const char *cp_suffix = &cp[blen]; if (0 == strcmp(cp_suffix, "_init")) { - if ((bz > 0) && (bp != 0)) { - if (blen >= bz) { - return 0 /* XXX internal error */; - } - strncpy(bp, cp, blen); - bp[blen] = '\0'; + int wc = snprintf(bp, bz, "%*.*s", (int)blen, (int)blen, cp); + if ((wc < 0) || ((size_t)wc >= bz)) { + return -1 /* XXX internal error */; } return 1 /* true */; } @@ -172,8 +167,8 @@ static inline int mca_coll_ucc_init_cts_is_p(const char *cp, char *bp, size_t bz } /* is an alias (special) name */ -static inline int mca_coll_ucc_init_cts_is_a(const char *cp, bool disable, - mca_coll_ucc_component_t *cm) +static inline int mca_coll_ucc_init_cts_is_alias(const char *cp, bool disable, + mca_coll_ucc_component_t *cm) { if (0 == strcmp(cp, "colls_b")) { /* all blocking colls */ if (disable) { @@ -192,9 +187,9 @@ static inline int mca_coll_ucc_init_cts_is_a(const char *cp, bool disable, return 1 /* true */; } else if (0 == strcmp(cp, "colls_p")) { /* all persistent colls */ if (disable) { - cm->pc_cts_requested &= ~COLL_UCC_CTS; + cm->ps_cts_requested &= ~COLL_UCC_CTS; } else { - cm->pc_cts_requested |= COLL_UCC_CTS; + cm->ps_cts_requested |= COLL_UCC_CTS; } return 1 /* true */; } @@ -215,33 +210,29 @@ static void mca_coll_ucc_init_default_cts(void) n_cts = opal_argv_count(cts); cm->cts_requested = disable ? COLL_UCC_CTS : 0; cm->nb_cts_requested = disable ? COLL_UCC_CTS : 0; - cm->pc_cts_requested = 0; /* XXX PC currently disabled by default */ + cm->ps_cts_requested = disable ? COLL_UCC_CTS : 0; for (i = 0; i < n_cts; i++) { char l_str[64]; /* XXX sizeof("reduce_scatter_block") */ size_t l_stz = sizeof(l_str); - if (0 < mca_coll_ucc_init_cts_is_a(cts[i], disable, cm)) { + if (0 < mca_coll_ucc_init_cts_is_alias(cts[i], disable, cm)) { continue; } if (('i' == cts[i][0]) || ('I' == cts[i][0])) { /* non blocking collective setting */ str = cts[i] + 1; ct = &cm->nb_cts_requested; - } else if (0 < mca_coll_ucc_init_cts_is_p(cts[i], l_str, l_stz)) { + } else if (0 < mca_coll_ucc_init_cts_is_persistent(cts[i], l_str, l_stz)) { /* persistent collective setting */ str = l_str; - ct = &cm->pc_cts_requested; + ct = &cm->ps_cts_requested; } else { str = cts[i]; ct = &cm->cts_requested; } c = mca_coll_ucc_str_to_type(str); if (UCC_COLL_TYPE_LAST == c) { - if (&cm->pc_cts_requested != ct) { - *ct = COLL_UCC_CTS; - } else { - *ct = 0; /* XXX PC currently disabled by default */ - } + *ct = COLL_UCC_CTS; break; } if (disable) { diff --git a/ompi/mca/coll/ucc/coll_ucc_gather.c b/ompi/mca/coll/ucc/coll_ucc_gather.c index 6deecb317ae..ebbcfe0adc9 100644 --- a/ompi/mca/coll/ucc/coll_ucc_gather.c +++ b/ompi/mca/coll/ucc/coll_ucc_gather.c @@ -12,10 +12,11 @@ #include "coll_ucc_common.h" static inline ucc_status_t -mca_coll_ucc_gather_iniz(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, - void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype, int root, - bool persistent, mca_coll_ucc_module_t *ucc_module, ucc_coll_req_h *req, - mca_coll_ucc_req_t *coll_req) +mca_coll_ucc_gather_init_common(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, + void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype, + int root, bool persistent, mca_coll_ucc_module_t *ucc_module, + ucc_coll_req_h *req, + mca_coll_ucc_req_t *coll_req) { ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; bool is_inplace = (MPI_IN_PLACE == sbuf); @@ -95,8 +96,9 @@ int mca_coll_ucc_gather(const void *sbuf, size_t scount, struct ompi_datatype_t ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc gather"); - COLL_UCC_CHECK(mca_coll_ucc_gather_iniz(sbuf, scount, sdtype, rbuf, rcount, rdtype, root, false, - ucc_module, &req, NULL)); + COLL_UCC_CHECK(mca_coll_ucc_gather_init_common(sbuf, scount, sdtype, rbuf, rcount, + rdtype, root, false, ucc_module, + &req, NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); return OMPI_SUCCESS; @@ -119,8 +121,9 @@ int mca_coll_ucc_igather(const void *sbuf, size_t scount, struct ompi_datatype_t UCC_VERBOSE(3, "running ucc igather"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_gather_iniz(sbuf, scount, sdtype, rbuf, rcount, rdtype, root, false, - ucc_module, &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_gather_init_common(sbuf, scount, sdtype, rbuf, rcount, + rdtype, root, false, ucc_module, + &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; return OMPI_SUCCESS; @@ -143,10 +146,11 @@ int mca_coll_ucc_gather_init(const void *sbuf, size_t scount, struct ompi_dataty ucc_coll_req_h req; mca_coll_ucc_req_t *coll_req = NULL; - COLL_UCC_GET_REQ_PC(coll_req); + COLL_UCC_GET_REQ_PERSISTENT(coll_req); UCC_VERBOSE(3, "gather_init init %p", coll_req); - COLL_UCC_CHECK(mca_coll_ucc_gather_iniz(sbuf, scount, sdtype, rbuf, rcount, rdtype, root, true, - ucc_module, &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_gather_init_common(sbuf, scount, sdtype, rbuf, rcount, + rdtype, root, true, ucc_module, + &req, coll_req)); *request = &coll_req->super; return OMPI_SUCCESS; fallback: diff --git a/ompi/mca/coll/ucc/coll_ucc_gatherv.c b/ompi/mca/coll/ucc/coll_ucc_gatherv.c index 5bc0a18279b..abbdde5a77b 100644 --- a/ompi/mca/coll/ucc/coll_ucc_gatherv.c +++ b/ompi/mca/coll/ucc/coll_ucc_gatherv.c @@ -12,11 +12,12 @@ #include "coll_ucc_common.h" static inline ucc_status_t -mca_coll_ucc_gatherv_iniz(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, - void *rbuf, ompi_count_array_t rcounts, ompi_disp_array_t disps, - struct ompi_datatype_t *rdtype, int root, bool persistent, - mca_coll_ucc_module_t *ucc_module, ucc_coll_req_h *req, - mca_coll_ucc_req_t *coll_req) +mca_coll_ucc_gatherv_init_common(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, + void *rbuf, ompi_count_array_t rcounts, ompi_disp_array_t disps, + struct ompi_datatype_t *rdtype, int root, + bool persistent, mca_coll_ucc_module_t *ucc_module, + ucc_coll_req_h *req, + mca_coll_ucc_req_t *coll_req) { ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; bool is_inplace = (MPI_IN_PLACE == sbuf); @@ -46,7 +47,8 @@ mca_coll_ucc_gatherv_iniz(const void *sbuf, size_t scount, struct ompi_datatype_ flags = (ompi_count_array_is_64bit(rcounts) ? UCC_COLL_ARGS_FLAG_COUNT_64BIT : 0) | (ompi_disp_array_is_64bit(disps) ? UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT : 0) | - (is_inplace ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0); + (is_inplace ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0) | + (persistent ? UCC_COLL_ARGS_FLAG_PERSISTENT : 0); ucc_coll_args_t coll = { .mask = flags ? UCC_COLL_ARGS_FIELD_FLAGS : 0, @@ -68,10 +70,6 @@ mca_coll_ucc_gatherv_iniz(const void *sbuf, size_t scount, struct ompi_datatype_ }, }; - if (true == persistent) { - coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; - coll.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; - } COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module); return UCC_OK; fallback: @@ -88,8 +86,9 @@ int mca_coll_ucc_gatherv(const void *sbuf, size_t scount, struct ompi_datatype_t ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc gatherv"); - COLL_UCC_CHECK(mca_coll_ucc_gatherv_iniz(sbuf, scount, sdtype, rbuf, rcounts, disps, rdtype, - root, false, ucc_module, &req, NULL)); + COLL_UCC_CHECK(mca_coll_ucc_gatherv_init_common(sbuf, scount, sdtype, rbuf, rcounts, + disps, rdtype, root, false, ucc_module, + &req, NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); return OMPI_SUCCESS; @@ -113,8 +112,9 @@ int mca_coll_ucc_igatherv(const void *sbuf, size_t scount, struct ompi_datatype_ UCC_VERBOSE(3, "running ucc igatherv"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_gatherv_iniz(sbuf, scount, sdtype, rbuf, rcounts, disps, rdtype, - root, false, ucc_module, &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_gatherv_init_common(sbuf, scount, sdtype, rbuf, rcounts, + disps, rdtype, root, false, ucc_module, + &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; return OMPI_SUCCESS; @@ -138,10 +138,11 @@ int mca_coll_ucc_gatherv_init(const void *sbuf, size_t scount, struct ompi_datat ucc_coll_req_h req; mca_coll_ucc_req_t *coll_req = NULL; - COLL_UCC_GET_REQ_PC(coll_req); + COLL_UCC_GET_REQ_PERSISTENT(coll_req); UCC_VERBOSE(3, "gatherv_init init %p", coll_req); - COLL_UCC_CHECK(mca_coll_ucc_gatherv_iniz(sbuf, scount, sdtype, rbuf, rcounts, disps, rdtype, - root, true, ucc_module, &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_gatherv_init_common(sbuf, scount, sdtype, rbuf, rcounts, + disps, rdtype, root, true, ucc_module, + &req, coll_req)); *request = &coll_req->super; return OMPI_SUCCESS; fallback: diff --git a/ompi/mca/coll/ucc/coll_ucc_module.c b/ompi/mca/coll/ucc/coll_ucc_module.c index c3c0b11e3ec..38901bc1403 100644 --- a/ompi/mca/coll/ucc/coll_ucc_module.c +++ b/ompi/mca/coll/ucc/coll_ucc_module.c @@ -88,34 +88,34 @@ static void mca_coll_ucc_module_clear(mca_coll_ucc_module_t *ucc_module) ucc_module->previous_scatter_module = NULL; ucc_module->previous_iscatter = NULL; ucc_module->previous_iscatter_module = NULL; - ucc_module->previous_allreduce_init = NULL; - ucc_module->previous_allreduce_init_module = NULL; - ucc_module->previous_barrier_init = NULL; - ucc_module->previous_barrier_init_module = NULL; - ucc_module->previous_bcast_init = NULL; - ucc_module->previous_bcast_init_module = NULL; - ucc_module->previous_alltoall_init = NULL; - ucc_module->previous_alltoall_init_module = NULL; - ucc_module->previous_alltoallv_init = NULL; - ucc_module->previous_alltoallv_init_module = NULL; - ucc_module->previous_allgather_init = NULL; - ucc_module->previous_allgather_init_module = NULL; - ucc_module->previous_allgatherv_init = NULL; - ucc_module->previous_allgatherv_init_module = NULL; - ucc_module->previous_reduce_init = NULL; - ucc_module->previous_reduce_init_module = NULL; - ucc_module->previous_gather_init = NULL; - ucc_module->previous_gather_init_module = NULL; - ucc_module->previous_gatherv_init = NULL; - ucc_module->previous_gatherv_init_module = NULL; - ucc_module->previous_reduce_scatter_block_init = NULL; + ucc_module->previous_allreduce_init = NULL; + ucc_module->previous_allreduce_init_module = NULL; + ucc_module->previous_barrier_init = NULL; + ucc_module->previous_barrier_init_module = NULL; + ucc_module->previous_bcast_init = NULL; + ucc_module->previous_bcast_init_module = NULL; + ucc_module->previous_alltoall_init = NULL; + ucc_module->previous_alltoall_init_module = NULL; + ucc_module->previous_alltoallv_init = NULL; + ucc_module->previous_alltoallv_init_module = NULL; + ucc_module->previous_allgather_init = NULL; + ucc_module->previous_allgather_init_module = NULL; + ucc_module->previous_allgatherv_init = NULL; + ucc_module->previous_allgatherv_init_module = NULL; + ucc_module->previous_reduce_init = NULL; + ucc_module->previous_reduce_init_module = NULL; + ucc_module->previous_gather_init = NULL; + ucc_module->previous_gather_init_module = NULL; + ucc_module->previous_gatherv_init = NULL; + ucc_module->previous_gatherv_init_module = NULL; + ucc_module->previous_reduce_scatter_block_init = NULL; ucc_module->previous_reduce_scatter_block_init_module = NULL; - ucc_module->previous_reduce_scatter_init = NULL; - ucc_module->previous_reduce_scatter_init_module = NULL; - ucc_module->previous_scatterv_init = NULL; - ucc_module->previous_scatterv_init_module = NULL; - ucc_module->previous_scatter_init = NULL; - ucc_module->previous_scatter_init_module = NULL; + ucc_module->previous_reduce_scatter_init = NULL; + ucc_module->previous_reduce_scatter_init_module = NULL; + ucc_module->previous_scatterv_init = NULL; + ucc_module->previous_scatterv_init_module = NULL; + ucc_module->previous_scatter_init = NULL; + ucc_module->previous_scatter_init_module = NULL; } static void mca_coll_ucc_module_construct(mca_coll_ucc_module_t *ucc_module) @@ -425,7 +425,7 @@ static inline ucc_ep_map_t get_rank_map(struct ompi_communicator_t *comm) MCA_COLL_INSTALL_API(__comm, i##__api, mca_coll_ucc_i##__api, &__ucc_module->super, "ucc"); \ (__ucc_module)->super.coll_i##__api = mca_coll_ucc_i##__api; \ } \ - if (mca_coll_ucc_component.pc_cts_requested & UCC_COLL_TYPE_##__COLL) \ + if (mca_coll_ucc_component.ps_cts_requested & UCC_COLL_TYPE_##__COLL) \ { \ MCA_COLL_SAVE_API(__comm, __api##_init, (__ucc_module)->previous_##__api##_init, (__ucc_module)->previous_##__api##_init_module, "ucc"); \ MCA_COLL_INSTALL_API(__comm, __api##_init, mca_coll_ucc_##__api##_init, &__ucc_module->super, "ucc"); \ @@ -565,17 +565,17 @@ mca_coll_ucc_module_disable(mca_coll_base_module_t *module, UCC_UNINSTALL_COLL_API(comm, ucc_module, reduce); UCC_UNINSTALL_COLL_API(comm, ucc_module, ireduce); UCC_UNINSTALL_COLL_API(comm, ucc_module, gather); - /* UCC_UNINSTALL_COLL_API(comm, ucc_module, igather); */ + UCC_UNINSTALL_COLL_API(comm, ucc_module, igather); UCC_UNINSTALL_COLL_API(comm, ucc_module, gatherv); - /* UCC_UNINSTALL_COLL_API(comm, ucc_module, igatherv); */ + UCC_UNINSTALL_COLL_API(comm, ucc_module, igatherv); UCC_UNINSTALL_COLL_API(comm, ucc_module, reduce_scatter_block); - /* UCC_UNINSTALL_COLL_API(comm, ucc_module, ireduce_scatter_block); */ + UCC_UNINSTALL_COLL_API(comm, ucc_module, ireduce_scatter_block); UCC_UNINSTALL_COLL_API(comm, ucc_module, reduce_scatter); - /* UCC_UNINSTALL_COLL_API(comm, ucc_module, ireduce_scatter); */ + UCC_UNINSTALL_COLL_API(comm, ucc_module, ireduce_scatter); UCC_UNINSTALL_COLL_API(comm, ucc_module, scatter); - /* UCC_UNINSTALL_COLL_API(comm, ucc_module, iscatter); */ + UCC_UNINSTALL_COLL_API(comm, ucc_module, iscatter); UCC_UNINSTALL_COLL_API(comm, ucc_module, scatterv); - /* UCC_UNINSTALL_COLL_API(comm, ucc_module, iscatterv); */ + UCC_UNINSTALL_COLL_API(comm, ucc_module, iscatterv); UCC_UNINSTALL_COLL_API(comm, ucc_module, allreduce_init); UCC_UNINSTALL_COLL_API(comm, ucc_module, barrier_init); @@ -648,7 +648,7 @@ OBJ_CLASS_INSTANCE(mca_coll_ucc_req_t, ompi_request_t, int mca_coll_ucc_req_free(struct ompi_request_t **ompi_req) { - if (MPI_REQUEST_NULL != ompi_req[0]) { + { mca_coll_ucc_req_t *coll_req = (mca_coll_ucc_req_t *) ompi_req[0]; if (true == coll_req->super.req_persistent) { UCC_VERBOSE(5, "%s free %p", "_init", coll_req); @@ -694,6 +694,10 @@ int mca_coll_ucc_req_start(size_t count, struct ompi_request_t **requests) continue; } if (true != coll_req->super.req_persistent) { + coll_req->super.req_status.MPI_ERROR = MPI_ERR_REQUEST; + if (OMPI_SUCCESS == rc) { + rc = OMPI_ERROR; + } continue; } UCC_VERBOSE(5, "%s post %p", "_init", coll_req); @@ -710,7 +714,7 @@ int mca_coll_ucc_req_start(size_t count, struct ompi_request_t **requests) if (UCC_OK != rc_ucc) { UCC_ERROR("ucc_collective_post failed: %s", ucc_status_string(rc_ucc)); coll_req->super.req_complete = REQUEST_COMPLETED; - coll_req->super.req_state = OMPI_REQUEST_INACTIVE; + coll_req->super.req_status.MPI_ERROR = MPI_ERR_INTERN; if (OMPI_SUCCESS == rc) { rc = OMPI_ERROR; } diff --git a/ompi/mca/coll/ucc/coll_ucc_reduce.c b/ompi/mca/coll/ucc/coll_ucc_reduce.c index ed719525334..305b0f9c8a9 100644 --- a/ompi/mca/coll/ucc/coll_ucc_reduce.c +++ b/ompi/mca/coll/ucc/coll_ucc_reduce.c @@ -9,12 +9,12 @@ #include "coll_ucc_common.h" -static inline ucc_status_t mca_coll_ucc_reduce_iniz(const void *sbuf, void *rbuf, size_t count, - struct ompi_datatype_t *dtype, - struct ompi_op_t *op, int root, bool persistent, - mca_coll_ucc_module_t *ucc_module, - ucc_coll_req_h *req, - mca_coll_ucc_req_t *coll_req) +static inline ucc_status_t mca_coll_ucc_reduce_init_common(const void *sbuf, void *rbuf, size_t count, + struct ompi_datatype_t *dtype, + struct ompi_op_t *op, int root, + bool persistent, mca_coll_ucc_module_t *ucc_module, + ucc_coll_req_h *req, + mca_coll_ucc_req_t *coll_req) { ucc_datatype_t ucc_dt; ucc_reduction_op_t ucc_op; @@ -74,8 +74,8 @@ int mca_coll_ucc_reduce(const void *sbuf, void* rbuf, size_t count, ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc reduce"); - COLL_UCC_CHECK(mca_coll_ucc_reduce_iniz(sbuf, rbuf, count, dtype, op, root, false, ucc_module, - &req, NULL)); + COLL_UCC_CHECK(mca_coll_ucc_reduce_init_common(sbuf, rbuf, count, dtype, op, + root, false, ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); return OMPI_SUCCESS; @@ -98,8 +98,8 @@ int mca_coll_ucc_ireduce(const void *sbuf, void* rbuf, size_t count, UCC_VERBOSE(3, "running ucc ireduce"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_reduce_iniz(sbuf, rbuf, count, dtype, op, root, false, ucc_module, - &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_reduce_init_common(sbuf, rbuf, count, dtype, op, root, + false, ucc_module, &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; return OMPI_SUCCESS; @@ -121,10 +121,10 @@ int mca_coll_ucc_reduce_init(const void *sbuf, void *rbuf, size_t count, ucc_coll_req_h req; mca_coll_ucc_req_t *coll_req = NULL; - COLL_UCC_GET_REQ_PC(coll_req); + COLL_UCC_GET_REQ_PERSISTENT(coll_req); UCC_VERBOSE(3, "reduce_init init %p", coll_req); - COLL_UCC_CHECK(mca_coll_ucc_reduce_iniz(sbuf, rbuf, count, dtype, op, root, true, ucc_module, - &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_reduce_init_common(sbuf, rbuf, count, dtype, op, root, + true, ucc_module, &req, coll_req)); *request = &coll_req->super; return OMPI_SUCCESS; fallback: diff --git a/ompi/mca/coll/ucc/coll_ucc_reduce_scatter.c b/ompi/mca/coll/ucc/coll_ucc_reduce_scatter.c index fe1a591b474..7ba6effb774 100644 --- a/ompi/mca/coll/ucc/coll_ucc_reduce_scatter.c +++ b/ompi/mca/coll/ucc/coll_ucc_reduce_scatter.c @@ -11,10 +11,11 @@ #include "coll_ucc_common.h" static inline ucc_status_t -mca_coll_ucc_reduce_scatter_iniz(const void *sbuf, void *rbuf, ompi_count_array_t rcounts, - struct ompi_datatype_t *dtype, struct ompi_op_t *op, - bool persistent, mca_coll_ucc_module_t *ucc_module, - ucc_coll_req_h *req, mca_coll_ucc_req_t *coll_req) +mca_coll_ucc_reduce_scatter_init_common(const void *sbuf, void *rbuf, ompi_count_array_t rcounts, + struct ompi_datatype_t *dtype, + struct ompi_op_t *op, bool persistent, mca_coll_ucc_module_t *ucc_module, + ucc_coll_req_h *req, + mca_coll_ucc_req_t *coll_req) { ucc_datatype_t ucc_dt; ucc_reduction_op_t ucc_op; @@ -47,7 +48,8 @@ mca_coll_ucc_reduce_scatter_iniz(const void *sbuf, void *rbuf, ompi_count_array_ total_count += ompi_count_array_get(rcounts, i); } - flags = (ompi_count_array_is_64bit(rcounts) ? UCC_COLL_ARGS_FLAG_COUNT_64BIT : 0); + flags = (ompi_count_array_is_64bit(rcounts) ? UCC_COLL_ARGS_FLAG_COUNT_64BIT : 0) | + (persistent ? UCC_COLL_ARGS_FLAG_PERSISTENT : 0); ucc_coll_args_t coll = { .mask = flags ? UCC_COLL_ARGS_FIELD_FLAGS : 0, @@ -67,11 +69,6 @@ mca_coll_ucc_reduce_scatter_iniz(const void *sbuf, void *rbuf, ompi_count_array_ }, .op = ucc_op, }; - - if (true == persistent) { - coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; - coll.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; - } COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module); return UCC_OK; fallback: @@ -88,8 +85,8 @@ int mca_coll_ucc_reduce_scatter(const void *sbuf, void *rbuf, ompi_count_array_t ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc reduce_scatter"); - COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_iniz(sbuf, rbuf, rcounts, dtype, op, false, - ucc_module, &req, NULL)); + COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_init_common(sbuf, rbuf, rcounts, dtype, + op, false, ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); return OMPI_SUCCESS; @@ -113,8 +110,8 @@ int mca_coll_ucc_ireduce_scatter(const void *sbuf, void *rbuf, ompi_count_array_ UCC_VERBOSE(3, "running ucc ireduce_scatter"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_iniz(sbuf, rbuf, rcounts, dtype, op, false, - ucc_module, &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_init_common(sbuf, rbuf, rcounts, dtype, + op, false, ucc_module, &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; return OMPI_SUCCESS; @@ -137,10 +134,10 @@ int mca_coll_ucc_reduce_scatter_init(const void *sbuf, void *rbuf, ompi_count_ar ucc_coll_req_h req; mca_coll_ucc_req_t *coll_req = NULL; - COLL_UCC_GET_REQ_PC(coll_req); + COLL_UCC_GET_REQ_PERSISTENT(coll_req); UCC_VERBOSE(3, "reduce_scatter_init init %p", coll_req); - COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_iniz(sbuf, rbuf, rcounts, dtype, op, true, - ucc_module, &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_init_common(sbuf, rbuf, rcounts, dtype, + op, true, ucc_module, &req, coll_req)); *request = &coll_req->super; return OMPI_SUCCESS; fallback: diff --git a/ompi/mca/coll/ucc/coll_ucc_reduce_scatter_block.c b/ompi/mca/coll/ucc/coll_ucc_reduce_scatter_block.c index 1d88647a24a..93b31a30ee8 100644 --- a/ompi/mca/coll/ucc/coll_ucc_reduce_scatter_block.c +++ b/ompi/mca/coll/ucc/coll_ucc_reduce_scatter_block.c @@ -11,10 +11,13 @@ #include "coll_ucc_common.h" static inline ucc_status_t -mca_coll_ucc_reduce_scatter_block_iniz(const void *sbuf, void *rbuf, size_t rcount, - struct ompi_datatype_t *dtype, struct ompi_op_t *op, - bool persistent, mca_coll_ucc_module_t *ucc_module, - ucc_coll_req_h *req, mca_coll_ucc_req_t *coll_req) +mca_coll_ucc_reduce_scatter_block_init_common(const void *sbuf, void *rbuf, + size_t rcount, + struct ompi_datatype_t *dtype, + struct ompi_op_t *op, + bool persistent, mca_coll_ucc_module_t *ucc_module, + ucc_coll_req_h *req, + mca_coll_ucc_req_t *coll_req) { ucc_datatype_t ucc_dt; ucc_reduction_op_t ucc_op; @@ -77,8 +80,9 @@ int mca_coll_ucc_reduce_scatter_block(const void *sbuf, void *rbuf, size_t rcoun ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc reduce scatter block"); - COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_block_iniz(sbuf, rbuf, rcount, dtype, op, false, - ucc_module, &req, NULL)); + COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_block_init_common(sbuf, rbuf, rcount, + dtype, op, false, ucc_module, + &req, NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); return OMPI_SUCCESS; @@ -102,8 +106,9 @@ int mca_coll_ucc_ireduce_scatter_block(const void *sbuf, void *rbuf, size_t rcou UCC_VERBOSE(3, "running ucc ireduce_scatter_block"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_block_iniz(sbuf, rbuf, rcount, dtype, op, false, - ucc_module, &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_block_init_common(sbuf, rbuf, rcount, + dtype, op, false, ucc_module, + &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; return OMPI_SUCCESS; @@ -127,10 +132,11 @@ int mca_coll_ucc_reduce_scatter_block_init(const void *sbuf, void *rbuf, size_t ucc_coll_req_h req; mca_coll_ucc_req_t *coll_req = NULL; - COLL_UCC_GET_REQ_PC(coll_req); + COLL_UCC_GET_REQ_PERSISTENT(coll_req); UCC_VERBOSE(3, "reduce_scatter_block_init init %p", coll_req); - COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_block_iniz(sbuf, rbuf, rcount, dtype, op, true, - ucc_module, &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_block_init_common(sbuf, rbuf, rcount, + dtype, op, true, ucc_module, + &req, coll_req)); *request = &coll_req->super; return OMPI_SUCCESS; fallback: diff --git a/ompi/mca/coll/ucc/coll_ucc_scatter.c b/ompi/mca/coll/ucc/coll_ucc_scatter.c index bf621a3f01a..0058f94e7ff 100644 --- a/ompi/mca/coll/ucc/coll_ucc_scatter.c +++ b/ompi/mca/coll/ucc/coll_ucc_scatter.c @@ -11,10 +11,13 @@ #include "coll_ucc_common.h" static inline ucc_status_t -mca_coll_ucc_scatter_iniz(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, - void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype, int root, - bool persistent, mca_coll_ucc_module_t *ucc_module, ucc_coll_req_h *req, - mca_coll_ucc_req_t *coll_req) +mca_coll_ucc_scatter_init_common(const void *sbuf, size_t scount, + struct ompi_datatype_t *sdtype, + void *rbuf, size_t rcount, + struct ompi_datatype_t *rdtype, int root, + bool persistent, mca_coll_ucc_module_t *ucc_module, + ucc_coll_req_h *req, + mca_coll_ucc_req_t *coll_req) { ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; bool is_inplace = (MPI_IN_PLACE == rbuf); @@ -95,8 +98,9 @@ int mca_coll_ucc_scatter(const void *sbuf, size_t scount, ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc scatter"); - COLL_UCC_CHECK(mca_coll_ucc_scatter_iniz(sbuf, scount, sdtype, rbuf, rcount, rdtype, root, - false, ucc_module, &req, NULL)); + COLL_UCC_CHECK(mca_coll_ucc_scatter_init_common(sbuf, scount, sdtype, rbuf, rcount, + rdtype, root, false, ucc_module, &req, + NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); return OMPI_SUCCESS; @@ -121,8 +125,9 @@ int mca_coll_ucc_iscatter(const void *sbuf, size_t scount, UCC_VERBOSE(3, "running ucc iscatter"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_scatter_iniz(sbuf, scount, sdtype, rbuf, rcount, rdtype, root, - false, ucc_module, &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_scatter_init_common(sbuf, scount, sdtype, rbuf, rcount, + rdtype, root, false, ucc_module, &req, + coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; return OMPI_SUCCESS; @@ -145,10 +150,11 @@ int mca_coll_ucc_scatter_init(const void *sbuf, size_t scount, struct ompi_datat ucc_coll_req_h req; mca_coll_ucc_req_t *coll_req = NULL; - COLL_UCC_GET_REQ_PC(coll_req); + COLL_UCC_GET_REQ_PERSISTENT(coll_req); UCC_VERBOSE(3, "scatter_init init %p", coll_req); - COLL_UCC_CHECK(mca_coll_ucc_scatter_iniz(sbuf, scount, sdtype, rbuf, rcount, rdtype, root, true, - ucc_module, &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_scatter_init_common(sbuf, scount, sdtype, rbuf, rcount, + rdtype, root, true, ucc_module, &req, + coll_req)); *request = &coll_req->super; return OMPI_SUCCESS; fallback: diff --git a/ompi/mca/coll/ucc/coll_ucc_scatterv.c b/ompi/mca/coll/ucc/coll_ucc_scatterv.c index 268a2e4599a..c1a611afd53 100644 --- a/ompi/mca/coll/ucc/coll_ucc_scatterv.c +++ b/ompi/mca/coll/ucc/coll_ucc_scatterv.c @@ -11,11 +11,13 @@ #include "coll_ucc_common.h" static inline ucc_status_t -mca_coll_ucc_scatterv_iniz(const void *sbuf, ompi_count_array_t scounts, ompi_disp_array_t disps, - struct ompi_datatype_t *sdtype, void *rbuf, size_t rcount, - struct ompi_datatype_t *rdtype, int root, bool persistent, - mca_coll_ucc_module_t *ucc_module, ucc_coll_req_h *req, - mca_coll_ucc_req_t *coll_req) +mca_coll_ucc_scatterv_init_common(const void *sbuf, ompi_count_array_t scounts, + ompi_disp_array_t disps, struct ompi_datatype_t *sdtype, + void *rbuf, size_t rcount, + struct ompi_datatype_t *rdtype, int root, + bool persistent, mca_coll_ucc_module_t *ucc_module, + ucc_coll_req_h *req, + mca_coll_ucc_req_t *coll_req) { ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; bool is_inplace = (MPI_IN_PLACE == rbuf); @@ -45,7 +47,8 @@ mca_coll_ucc_scatterv_iniz(const void *sbuf, ompi_count_array_t scounts, ompi_di flags = (ompi_count_array_is_64bit(scounts) ? UCC_COLL_ARGS_FLAG_COUNT_64BIT : 0) | (ompi_disp_array_is_64bit(disps) ? UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT : 0) | - (is_inplace ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0); + (is_inplace ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0) | + (persistent ? UCC_COLL_ARGS_FLAG_PERSISTENT : 0); ucc_coll_args_t coll = { .mask = flags ? UCC_COLL_ARGS_FIELD_FLAGS : 0, @@ -67,10 +70,6 @@ mca_coll_ucc_scatterv_iniz(const void *sbuf, ompi_count_array_t scounts, ompi_di }, }; - if (true == persistent) { - coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; - coll.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; - } COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module); return UCC_OK; fallback: @@ -88,8 +87,9 @@ int mca_coll_ucc_scatterv(const void *sbuf, ompi_count_array_t scounts, ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc scatterv"); - COLL_UCC_CHECK(mca_coll_ucc_scatterv_iniz(sbuf, scounts, disps, sdtype, rbuf, rcount, rdtype, - root, false, ucc_module, &req, NULL)); + COLL_UCC_CHECK(mca_coll_ucc_scatterv_init_common(sbuf, scounts, disps, sdtype, + rbuf, rcount, rdtype, root, + false, ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); return OMPI_SUCCESS; @@ -114,8 +114,9 @@ int mca_coll_ucc_iscatterv(const void *sbuf, ompi_count_array_t scounts, UCC_VERBOSE(3, "running ucc iscatterv"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_scatterv_iniz(sbuf, scounts, disps, sdtype, rbuf, rcount, rdtype, - root, false, ucc_module, &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_scatterv_init_common(sbuf, scounts, disps, sdtype, + rbuf, rcount, rdtype, root, + false, ucc_module, &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; return OMPI_SUCCESS; @@ -139,10 +140,11 @@ int mca_coll_ucc_scatterv_init(const void *sbuf, ompi_count_array_t scounts, ucc_coll_req_h req; mca_coll_ucc_req_t *coll_req = NULL; - COLL_UCC_GET_REQ_PC(coll_req); + COLL_UCC_GET_REQ_PERSISTENT(coll_req); UCC_VERBOSE(3, "scatterv_init init %p", coll_req); - COLL_UCC_CHECK(mca_coll_ucc_scatterv_iniz(sbuf, scounts, disps, sdtype, rbuf, rcount, rdtype, - root, true, ucc_module, &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_scatterv_init_common(sbuf, scounts, disps, sdtype, + rbuf, rcount, rdtype, root, + true, ucc_module, &req, coll_req)); *request = &coll_req->super; return OMPI_SUCCESS; fallback: From a8b244e91e0c7b4caa83cd738cfb81ac67992536 Mon Sep 17 00:00:00 2001 From: "hasegawa.kento" Date: Fri, 5 Sep 2025 14:37:38 +0900 Subject: [PATCH 5/5] COLL/UCC: Standardize flag and mask value assignment Signed-off-by: hasegawa.kento --- ompi/mca/coll/ucc/coll_ucc_allgather.c | 16 ++++++---------- ompi/mca/coll/ucc/coll_ucc_allreduce.c | 18 ++++++++---------- ompi/mca/coll/ucc/coll_ucc_alltoall.c | 16 ++++++---------- ompi/mca/coll/ucc/coll_ucc_barrier.c | 12 ++++++------ ompi/mca/coll/ucc/coll_ucc_bcast.c | 12 ++++++------ ompi/mca/coll/ucc/coll_ucc_gather.c | 16 ++++++---------- ompi/mca/coll/ucc/coll_ucc_reduce.c | 18 ++++++++---------- .../coll/ucc/coll_ucc_reduce_scatter_block.c | 12 ++++++------ ompi/mca/coll/ucc/coll_ucc_scatter.c | 16 ++++++---------- 9 files changed, 58 insertions(+), 78 deletions(-) diff --git a/ompi/mca/coll/ucc/coll_ucc_allgather.c b/ompi/mca/coll/ucc/coll_ucc_allgather.c index 2738468977c..2362cc038a1 100644 --- a/ompi/mca/coll/ucc/coll_ucc_allgather.c +++ b/ompi/mca/coll/ucc/coll_ucc_allgather.c @@ -20,6 +20,7 @@ mca_coll_ucc_allgather_init_common(const void *sbuf, size_t scount, struct ompi_ ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; bool is_inplace = (MPI_IN_PLACE == sbuf); int comm_size = ompi_comm_size(ucc_module->comm); + uint64_t flags = 0; if (!(is_inplace || ompi_datatype_is_contiguous_memory_layout(sdtype, scount)) || !ompi_datatype_is_contiguous_memory_layout(rdtype, rcount * comm_size)) { @@ -39,9 +40,12 @@ mca_coll_ucc_allgather_init_common(const void *sbuf, size_t scount, struct ompi_ goto fallback; } + flags = (is_inplace ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0) | + (persistent ? UCC_COLL_ARGS_FLAG_PERSISTENT : 0); + ucc_coll_args_t coll = { - .mask = 0, - .flags = 0, + .mask = flags ? UCC_COLL_ARGS_FIELD_FLAGS : 0, + .flags = flags, .coll_type = UCC_COLL_TYPE_ALLGATHER, .src.info = { .buffer = (void*)sbuf, @@ -57,14 +61,6 @@ mca_coll_ucc_allgather_init_common(const void *sbuf, size_t scount, struct ompi_ } }; - if (is_inplace) { - coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; - coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; - } - if (true == persistent) { - coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; - coll.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; - } COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module); return UCC_OK; fallback: diff --git a/ompi/mca/coll/ucc/coll_ucc_allreduce.c b/ompi/mca/coll/ucc/coll_ucc_allreduce.c index bad9ad1f886..ac8c990a939 100644 --- a/ompi/mca/coll/ucc/coll_ucc_allreduce.c +++ b/ompi/mca/coll/ucc/coll_ucc_allreduce.c @@ -18,6 +18,7 @@ static inline ucc_status_t mca_coll_ucc_allreduce_init_common(const void *sbuf, { ucc_datatype_t ucc_dt; ucc_reduction_op_t ucc_op; + uint64_t flags = 0; ucc_dt = ompi_dtype_to_ucc_dtype(dtype); ucc_op = ompi_op_to_ucc_op(op); @@ -31,9 +32,13 @@ static inline ucc_status_t mca_coll_ucc_allreduce_init_common(const void *sbuf, op->o_name); goto fallback; } + + flags = ((MPI_IN_PLACE == sbuf) ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0) | + (persistent ? UCC_COLL_ARGS_FLAG_PERSISTENT : 0); + ucc_coll_args_t coll = { - .mask = 0, - .flags = 0, + .mask = flags ? UCC_COLL_ARGS_FIELD_FLAGS : 0, + .flags = flags, .coll_type = UCC_COLL_TYPE_ALLREDUCE, .src.info = { .buffer = (void*)sbuf, @@ -49,14 +54,7 @@ static inline ucc_status_t mca_coll_ucc_allreduce_init_common(const void *sbuf, }, .op = ucc_op, }; - if (MPI_IN_PLACE == sbuf) { - coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; - coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; - } - if (true == persistent) { - coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; - coll.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; - } + COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module); return UCC_OK; fallback: diff --git a/ompi/mca/coll/ucc/coll_ucc_alltoall.c b/ompi/mca/coll/ucc/coll_ucc_alltoall.c index da56ce24502..f61171576b2 100644 --- a/ompi/mca/coll/ucc/coll_ucc_alltoall.c +++ b/ompi/mca/coll/ucc/coll_ucc_alltoall.c @@ -20,6 +20,7 @@ mca_coll_ucc_alltoall_init_common(const void *sbuf, size_t scount, struct ompi_d ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; bool is_inplace = (MPI_IN_PLACE == sbuf); int comm_size = ompi_comm_size(ucc_module->comm); + uint64_t flags = 0; if (!(is_inplace || ompi_datatype_is_contiguous_memory_layout(sdtype, scount * comm_size)) || !ompi_datatype_is_contiguous_memory_layout(rdtype, rcount * comm_size)) { @@ -39,9 +40,12 @@ mca_coll_ucc_alltoall_init_common(const void *sbuf, size_t scount, struct ompi_d goto fallback; } + flags = (is_inplace ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0) | + (persistent ? UCC_COLL_ARGS_FLAG_PERSISTENT : 0); + ucc_coll_args_t coll = { - .mask = 0, - .flags = 0, + .mask = flags ? UCC_COLL_ARGS_FIELD_FLAGS : 0, + .flags = flags, .coll_type = UCC_COLL_TYPE_ALLTOALL, .src.info = { .buffer = (void*)sbuf, @@ -57,14 +61,6 @@ mca_coll_ucc_alltoall_init_common(const void *sbuf, size_t scount, struct ompi_d } }; - if (is_inplace) { - coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; - coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; - } - if (true == persistent) { - coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; - coll.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; - } COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module); return UCC_OK; fallback: diff --git a/ompi/mca/coll/ucc/coll_ucc_barrier.c b/ompi/mca/coll/ucc/coll_ucc_barrier.c index 7deab58b4b0..da886e56f54 100644 --- a/ompi/mca/coll/ucc/coll_ucc_barrier.c +++ b/ompi/mca/coll/ucc/coll_ucc_barrier.c @@ -13,16 +13,16 @@ static inline ucc_status_t mca_coll_ucc_barrier_init_common(bool persistent, mca ucc_coll_req_h *req, mca_coll_ucc_req_t *coll_req) { + uint64_t flags = 0; + + flags = (persistent ? UCC_COLL_ARGS_FLAG_PERSISTENT : 0); + ucc_coll_args_t coll = { - .mask = 0, - .flags = 0, + .mask = flags ? UCC_COLL_ARGS_FIELD_FLAGS : 0, + .flags = flags, .coll_type = UCC_COLL_TYPE_BARRIER }; - if (true == persistent) { - coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; - coll.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; - } COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module); return UCC_OK; fallback: diff --git a/ompi/mca/coll/ucc/coll_ucc_bcast.c b/ompi/mca/coll/ucc/coll_ucc_bcast.c index f32771fd965..8da3c839133 100644 --- a/ompi/mca/coll/ucc/coll_ucc_bcast.c +++ b/ompi/mca/coll/ucc/coll_ucc_bcast.c @@ -16,14 +16,18 @@ mca_coll_ucc_bcast_init_common(void *buf, size_t count, struct ompi_datatype_t * mca_coll_ucc_req_t *coll_req) { ucc_datatype_t ucc_dt = ompi_dtype_to_ucc_dtype(dtype); + uint64_t flags = 0; + if (COLL_UCC_DT_UNSUPPORTED == ucc_dt) { UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s", dtype->super.name); goto fallback; } + flags = (persistent ? UCC_COLL_ARGS_FLAG_PERSISTENT : 0); + ucc_coll_args_t coll = { - .mask = 0, - .flags = 0, + .mask = flags ? UCC_COLL_ARGS_FIELD_FLAGS : 0, + .flags = flags, .coll_type = UCC_COLL_TYPE_BCAST, .root = root, .src.info = { @@ -34,10 +38,6 @@ mca_coll_ucc_bcast_init_common(void *buf, size_t count, struct ompi_datatype_t * } }; - if (true == persistent) { - coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; - coll.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; - } COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module); return UCC_OK; fallback: diff --git a/ompi/mca/coll/ucc/coll_ucc_gather.c b/ompi/mca/coll/ucc/coll_ucc_gather.c index ebbcfe0adc9..ad03d654b4c 100644 --- a/ompi/mca/coll/ucc/coll_ucc_gather.c +++ b/ompi/mca/coll/ucc/coll_ucc_gather.c @@ -22,6 +22,7 @@ mca_coll_ucc_gather_init_common(const void *sbuf, size_t scount, struct ompi_dat bool is_inplace = (MPI_IN_PLACE == sbuf); int comm_rank = ompi_comm_rank(ucc_module->comm); int comm_size = ompi_comm_size(ucc_module->comm); + uint64_t flags = 0; if (comm_rank == root) { if (!(is_inplace || ompi_datatype_is_contiguous_memory_layout(sdtype, scount)) || @@ -54,9 +55,12 @@ mca_coll_ucc_gather_init_common(const void *sbuf, size_t scount, struct ompi_dat } } + flags = (is_inplace ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0) | + (persistent ? UCC_COLL_ARGS_FLAG_PERSISTENT : 0); + ucc_coll_args_t coll = { - .mask = 0, - .flags = 0, + .mask = flags ? UCC_COLL_ARGS_FIELD_FLAGS : 0, + .flags = flags, .coll_type = UCC_COLL_TYPE_GATHER, .root = root, .src.info = { @@ -73,14 +77,6 @@ mca_coll_ucc_gather_init_common(const void *sbuf, size_t scount, struct ompi_dat }, }; - if (is_inplace) { - coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; - coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; - } - if (true == persistent) { - coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; - coll.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; - } COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module); return UCC_OK; fallback: diff --git a/ompi/mca/coll/ucc/coll_ucc_reduce.c b/ompi/mca/coll/ucc/coll_ucc_reduce.c index 305b0f9c8a9..c76b16c8881 100644 --- a/ompi/mca/coll/ucc/coll_ucc_reduce.c +++ b/ompi/mca/coll/ucc/coll_ucc_reduce.c @@ -18,6 +18,7 @@ static inline ucc_status_t mca_coll_ucc_reduce_init_common(const void *sbuf, voi { ucc_datatype_t ucc_dt; ucc_reduction_op_t ucc_op; + uint64_t flags = 0; ucc_dt = ompi_dtype_to_ucc_dtype(dtype); ucc_op = ompi_op_to_ucc_op(op); @@ -31,9 +32,13 @@ static inline ucc_status_t mca_coll_ucc_reduce_init_common(const void *sbuf, voi op->o_name); goto fallback; } + + flags = ((MPI_IN_PLACE == sbuf) ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0) | + (persistent ? UCC_COLL_ARGS_FLAG_PERSISTENT : 0); + ucc_coll_args_t coll = { - .mask = 0, - .flags = 0, + .mask = flags ? UCC_COLL_ARGS_FIELD_FLAGS : 0, + .flags = flags, .coll_type = UCC_COLL_TYPE_REDUCE, .root = root, .src.info = { @@ -50,14 +55,7 @@ static inline ucc_status_t mca_coll_ucc_reduce_init_common(const void *sbuf, voi }, .op = ucc_op, }; - if (MPI_IN_PLACE == sbuf) { - coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; - coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; - } - if (true == persistent) { - coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; - coll.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; - } + COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module); return UCC_OK; fallback: diff --git a/ompi/mca/coll/ucc/coll_ucc_reduce_scatter_block.c b/ompi/mca/coll/ucc/coll_ucc_reduce_scatter_block.c index 93b31a30ee8..49deba9393e 100644 --- a/ompi/mca/coll/ucc/coll_ucc_reduce_scatter_block.c +++ b/ompi/mca/coll/ucc/coll_ucc_reduce_scatter_block.c @@ -22,6 +22,7 @@ mca_coll_ucc_reduce_scatter_block_init_common(const void *sbuf, void *rbuf, ucc_datatype_t ucc_dt; ucc_reduction_op_t ucc_op; int comm_size = ompi_comm_size(ucc_module->comm); + uint64_t flags = 0; if (MPI_IN_PLACE == sbuf) { /* TODO: UCC defines inplace differently: @@ -41,9 +42,12 @@ mca_coll_ucc_reduce_scatter_block_init_common(const void *sbuf, void *rbuf, op->o_name); goto fallback; } + + flags = (persistent ? UCC_COLL_ARGS_FLAG_PERSISTENT : 0); + ucc_coll_args_t coll = { - .mask = 0, - .flags = 0, + .mask = flags ? UCC_COLL_ARGS_FIELD_FLAGS : 0, + .flags = flags, .coll_type = UCC_COLL_TYPE_REDUCE_SCATTER, .src.info = { .buffer = (void*)sbuf, @@ -60,10 +64,6 @@ mca_coll_ucc_reduce_scatter_block_init_common(const void *sbuf, void *rbuf, .op = ucc_op, }; - if (true == persistent) { - coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; - coll.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; - } COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module); return UCC_OK; fallback: diff --git a/ompi/mca/coll/ucc/coll_ucc_scatter.c b/ompi/mca/coll/ucc/coll_ucc_scatter.c index 0058f94e7ff..4f4e60eaec3 100644 --- a/ompi/mca/coll/ucc/coll_ucc_scatter.c +++ b/ompi/mca/coll/ucc/coll_ucc_scatter.c @@ -23,6 +23,7 @@ mca_coll_ucc_scatter_init_common(const void *sbuf, size_t scount, bool is_inplace = (MPI_IN_PLACE == rbuf); int comm_rank = ompi_comm_rank(ucc_module->comm); int comm_size = ompi_comm_size(ucc_module->comm); + uint64_t flags = 0; if (comm_rank == root) { if (!(is_inplace || ompi_datatype_is_contiguous_memory_layout(rdtype, rcount)) || @@ -55,9 +56,12 @@ mca_coll_ucc_scatter_init_common(const void *sbuf, size_t scount, } } + flags = (is_inplace ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0) | + (persistent ? UCC_COLL_ARGS_FLAG_PERSISTENT : 0); + ucc_coll_args_t coll = { - .mask = 0, - .flags = 0, + .mask = flags ? UCC_COLL_ARGS_FIELD_FLAGS : 0, + .flags = flags, .coll_type = UCC_COLL_TYPE_SCATTER, .root = root, .src.info = { @@ -74,14 +78,6 @@ mca_coll_ucc_scatter_init_common(const void *sbuf, size_t scount, }, }; - if (is_inplace) { - coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; - coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; - } - if (true == persistent) { - coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; - coll.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; - } COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module); return UCC_OK; fallback: