From 5e05bae9faab4ec7925a37db530b70436b238b99 Mon Sep 17 00:00:00 2001 From: Sergey Lebedev Date: Mon, 7 Oct 2024 11:42:59 +0000 Subject: [PATCH] v4.1.x: coll/ucc: add reduce scatter block Signed-off-by: Sergey Lebedev bot:notacherrypick --- ompi/mca/coll/ucc/Makefile.am | 31 ++--- ompi/mca/coll/ucc/coll_ucc.h | 96 ++++++++------ ompi/mca/coll/ucc/coll_ucc_component.c | 2 + ompi/mca/coll/ucc/coll_ucc_module.c | 59 +++++---- .../coll/ucc/coll_ucc_reduce_scatter_block.c | 117 ++++++++++++++++++ 5 files changed, 226 insertions(+), 79 deletions(-) create mode 100644 ompi/mca/coll/ucc/coll_ucc_reduce_scatter_block.c diff --git a/ompi/mca/coll/ucc/Makefile.am b/ompi/mca/coll/ucc/Makefile.am index c66ffa47694..a28957cc20c 100644 --- a/ompi/mca/coll/ucc/Makefile.am +++ b/ompi/mca/coll/ucc/Makefile.am @@ -12,21 +12,22 @@ AM_CPPFLAGS = $(coll_ucc_CPPFLAGS) -coll_ucc_sources = \ - coll_ucc.h \ - coll_ucc_debug.h \ - coll_ucc_dtypes.h \ - coll_ucc_common.h \ - coll_ucc_module.c \ - coll_ucc_component.c \ - coll_ucc_barrier.c \ - coll_ucc_bcast.c \ - coll_ucc_allreduce.c \ - coll_ucc_reduce.c \ - coll_ucc_alltoall.c \ - coll_ucc_alltoallv.c \ - coll_ucc_allgather.c \ - coll_ucc_allgatherv.c +coll_ucc_sources = \ + coll_ucc.h \ + coll_ucc_debug.h \ + coll_ucc_dtypes.h \ + coll_ucc_common.h \ + coll_ucc_module.c \ + coll_ucc_component.c \ + coll_ucc_barrier.c \ + coll_ucc_bcast.c \ + coll_ucc_allreduce.c \ + coll_ucc_reduce.c \ + coll_ucc_alltoall.c \ + coll_ucc_alltoallv.c \ + coll_ucc_allgather.c \ + coll_ucc_allgatherv.c \ + coll_ucc_reduce_scatter_block.c # Make the output library in this directory, and name it either # mca__.la (for DSO builds) or libmca__.la diff --git a/ompi/mca/coll/ucc/coll_ucc.h b/ompi/mca/coll/ucc/coll_ucc.h index 718d5411022..711a1f7f326 100644 --- a/ompi/mca/coll/ucc/coll_ucc.h +++ b/ompi/mca/coll/ucc/coll_ucc.h @@ -27,10 +27,11 @@ BEGIN_C_DECLS #define COLL_UCC_CTS (UCC_COLL_TYPE_BARRIER | UCC_COLL_TYPE_BCAST | \ UCC_COLL_TYPE_ALLREDUCE | UCC_COLL_TYPE_ALLTOALL | \ UCC_COLL_TYPE_ALLTOALLV | UCC_COLL_TYPE_ALLGATHER | \ - UCC_COLL_TYPE_REDUCE | UCC_COLL_TYPE_ALLGATHERV) + UCC_COLL_TYPE_REDUCE | UCC_COLL_TYPE_ALLGATHERV | \ + UCC_COLL_TYPE_REDUCE_SCATTER) -#define COLL_UCC_CTS_STR "barrier,bcast,allreduce,alltoall,alltoallv,allgather,allgatherv,reduce," \ - "ibarrier,ibcast,iallreduce,ialltoall,ialltoallv,iallgather,iallgatherv,ireduce" +#define COLL_UCC_CTS_STR "barrier,bcast,allreduce,alltoall,alltoallv,allgather,allgatherv,reduce,reduce_scatter_block," \ + "ibarrier,ibcast,iallreduce,ialltoall,ialltoallv,iallgather,iallgatherv,ireduce,ireduce_scatter_block" typedef struct mca_coll_ucc_req { ompi_request_t super; @@ -64,42 +65,46 @@ OMPI_MODULE_DECLSPEC extern mca_coll_ucc_component_t mca_coll_ucc_component; * UCC enabled communicator */ struct mca_coll_ucc_module_t { - mca_coll_base_module_t super; - ompi_communicator_t* comm; - int rank; - ucc_team_h ucc_team; - mca_coll_base_module_allreduce_fn_t previous_allreduce; - mca_coll_base_module_t* previous_allreduce_module; - mca_coll_base_module_iallreduce_fn_t previous_iallreduce; - mca_coll_base_module_t* previous_iallreduce_module; - mca_coll_base_module_reduce_fn_t previous_reduce; - mca_coll_base_module_t* previous_reduce_module; - mca_coll_base_module_ireduce_fn_t previous_ireduce; - mca_coll_base_module_t* previous_ireduce_module; - mca_coll_base_module_barrier_fn_t previous_barrier; - mca_coll_base_module_t* previous_barrier_module; - mca_coll_base_module_ibarrier_fn_t previous_ibarrier; - mca_coll_base_module_t* previous_ibarrier_module; - mca_coll_base_module_bcast_fn_t previous_bcast; - mca_coll_base_module_t* previous_bcast_module; - mca_coll_base_module_ibcast_fn_t previous_ibcast; - mca_coll_base_module_t* previous_ibcast_module; - mca_coll_base_module_alltoall_fn_t previous_alltoall; - mca_coll_base_module_t* previous_alltoall_module; - mca_coll_base_module_ialltoall_fn_t previous_ialltoall; - mca_coll_base_module_t* previous_ialltoall_module; - mca_coll_base_module_alltoallv_fn_t previous_alltoallv; - mca_coll_base_module_t* previous_alltoallv_module; - mca_coll_base_module_ialltoallv_fn_t previous_ialltoallv; - mca_coll_base_module_t* previous_ialltoallv_module; - mca_coll_base_module_allgather_fn_t previous_allgather; - mca_coll_base_module_t* previous_allgather_module; - mca_coll_base_module_iallgather_fn_t previous_iallgather; - mca_coll_base_module_t* previous_iallgather_module; - mca_coll_base_module_allgatherv_fn_t previous_allgatherv; - mca_coll_base_module_t* previous_allgatherv_module; - mca_coll_base_module_iallgatherv_fn_t previous_iallgatherv; - mca_coll_base_module_t* previous_iallgatherv_module; + mca_coll_base_module_t super; + ompi_communicator_t* comm; + int rank; + ucc_team_h ucc_team; + mca_coll_base_module_allreduce_fn_t previous_allreduce; + mca_coll_base_module_t* previous_allreduce_module; + mca_coll_base_module_iallreduce_fn_t previous_iallreduce; + mca_coll_base_module_t* previous_iallreduce_module; + mca_coll_base_module_reduce_fn_t previous_reduce; + mca_coll_base_module_t* previous_reduce_module; + mca_coll_base_module_ireduce_fn_t previous_ireduce; + mca_coll_base_module_t* previous_ireduce_module; + mca_coll_base_module_barrier_fn_t previous_barrier; + mca_coll_base_module_t* previous_barrier_module; + mca_coll_base_module_ibarrier_fn_t previous_ibarrier; + mca_coll_base_module_t* previous_ibarrier_module; + mca_coll_base_module_bcast_fn_t previous_bcast; + mca_coll_base_module_t* previous_bcast_module; + mca_coll_base_module_ibcast_fn_t previous_ibcast; + mca_coll_base_module_t* previous_ibcast_module; + mca_coll_base_module_alltoall_fn_t previous_alltoall; + mca_coll_base_module_t* previous_alltoall_module; + mca_coll_base_module_ialltoall_fn_t previous_ialltoall; + mca_coll_base_module_t* previous_ialltoall_module; + mca_coll_base_module_alltoallv_fn_t previous_alltoallv; + mca_coll_base_module_t* previous_alltoallv_module; + mca_coll_base_module_ialltoallv_fn_t previous_ialltoallv; + mca_coll_base_module_t* previous_ialltoallv_module; + mca_coll_base_module_allgather_fn_t previous_allgather; + mca_coll_base_module_t* previous_allgather_module; + mca_coll_base_module_iallgather_fn_t previous_iallgather; + mca_coll_base_module_t* previous_iallgather_module; + mca_coll_base_module_allgatherv_fn_t previous_allgatherv; + mca_coll_base_module_t* previous_allgatherv_module; + mca_coll_base_module_iallgatherv_fn_t previous_iallgatherv; + mca_coll_base_module_t* previous_iallgatherv_module; + mca_coll_base_module_reduce_scatter_block_fn_t previous_reduce_scatter_block; + mca_coll_base_module_t* previous_reduce_scatter_block_module; + mca_coll_base_module_ireduce_scatter_block_fn_t previous_ireduce_scatter_block; + mca_coll_base_module_t* previous_ireduce_scatter_block_module; }; typedef struct mca_coll_ucc_module_t mca_coll_ucc_module_t; OBJ_CLASS_DECLARATION(mca_coll_ucc_module_t); @@ -195,5 +200,18 @@ int mca_coll_ucc_iallgatherv(const void *sbuf, int scount, struct ompi_datatype_ ompi_request_t** request, mca_coll_base_module_t *module); +int mca_coll_ucc_reduce_scatter_block(const void *sbuf, void *rbuf, int rcount, + struct ompi_datatype_t *dtype, + struct ompi_op_t *op, + struct ompi_communicator_t *comm, + mca_coll_base_module_t *module); + +int mca_coll_ucc_ireduce_scatter_block(const void *sbuf, void *rbuf, int rcount, + struct ompi_datatype_t *dtype, + struct ompi_op_t *op, + struct ompi_communicator_t *comm, + ompi_request_t** request, + mca_coll_base_module_t *module); + END_C_DECLS #endif diff --git a/ompi/mca/coll/ucc/coll_ucc_component.c b/ompi/mca/coll/ucc/coll_ucc_component.c index a70c6fb310e..a42341ca769 100644 --- a/ompi/mca/coll/ucc/coll_ucc_component.c +++ b/ompi/mca/coll/ucc/coll_ucc_component.c @@ -120,6 +120,8 @@ static ucc_coll_type_t mca_coll_ucc_str_to_type(const char *str) return UCC_COLL_TYPE_ALLGATHERV; } else if (0 == strcasecmp(str, "reduce")) { return UCC_COLL_TYPE_REDUCE; + } else if (0 == strcasecmp(str, "reduce_scatter_block")) { + return UCC_COLL_TYPE_REDUCE_SCATTER; } UCC_ERROR("incorrect value for cts: %s, allowed: %s", str, COLL_UCC_CTS_STR); diff --git a/ompi/mca/coll/ucc/coll_ucc_module.c b/ompi/mca/coll/ucc/coll_ucc_module.c index 16866976185..f46f83c9d05 100644 --- a/ompi/mca/coll/ucc/coll_ucc_module.c +++ b/ompi/mca/coll/ucc/coll_ucc_module.c @@ -29,23 +29,27 @@ int mca_coll_ucc_init_query(bool enable_progress_threads, bool enable_mpi_thread static void mca_coll_ucc_module_clear(mca_coll_ucc_module_t *ucc_module) { - ucc_module->ucc_team = NULL; - ucc_module->previous_allreduce = NULL; - ucc_module->previous_iallreduce = NULL; - ucc_module->previous_barrier = NULL; - ucc_module->previous_ibarrier = NULL; - ucc_module->previous_bcast = NULL; - ucc_module->previous_ibcast = NULL; - ucc_module->previous_alltoall = NULL; - ucc_module->previous_ialltoall = NULL; - ucc_module->previous_alltoallv = NULL; - ucc_module->previous_ialltoallv = NULL; - ucc_module->previous_allgather = NULL; - ucc_module->previous_iallgather = NULL; - ucc_module->previous_allgatherv = NULL; - ucc_module->previous_iallgatherv = NULL; - ucc_module->previous_reduce = NULL; - ucc_module->previous_ireduce = NULL; + ucc_module->ucc_team = NULL; + ucc_module->previous_allreduce = NULL; + ucc_module->previous_iallreduce = NULL; + ucc_module->previous_barrier = NULL; + ucc_module->previous_ibarrier = NULL; + ucc_module->previous_bcast = NULL; + ucc_module->previous_ibcast = NULL; + ucc_module->previous_alltoall = NULL; + ucc_module->previous_ialltoall = NULL; + ucc_module->previous_alltoallv = NULL; + ucc_module->previous_ialltoallv = NULL; + ucc_module->previous_allgather = NULL; + ucc_module->previous_iallgather = NULL; + ucc_module->previous_allgatherv = NULL; + ucc_module->previous_iallgatherv = NULL; + ucc_module->previous_reduce = NULL; + ucc_module->previous_ireduce = NULL; + ucc_module->previous_reduce_scatter_block = NULL; + ucc_module->previous_reduce_scatter_block_module = NULL; + ucc_module->previous_ireduce_scatter_block = NULL; + ucc_module->previous_ireduce_scatter_block_module = NULL; } static void mca_coll_ucc_module_construct(mca_coll_ucc_module_t *ucc_module) @@ -82,6 +86,8 @@ static void mca_coll_ucc_module_destruct(mca_coll_ucc_module_t *ucc_module) OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_iallgatherv_module); OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_reduce_module); OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_ireduce_module); + OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_reduce_scatter_block_module); + OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_ireduce_scatter_block_module); mca_coll_ucc_module_clear(ucc_module); } @@ -113,6 +119,8 @@ static int mca_coll_ucc_save_coll_handlers(mca_coll_ucc_module_t *ucc_module) SAVE_PREV_COLL_API(iallgatherv); SAVE_PREV_COLL_API(reduce); SAVE_PREV_COLL_API(ireduce); + SAVE_PREV_COLL_API(reduce_scatter_block); + SAVE_PREV_COLL_API(ireduce_scatter_block); return OMPI_SUCCESS; } @@ -491,14 +499,15 @@ mca_coll_ucc_comm_query(struct ompi_communicator_t *comm, int *priority) ucc_module->comm = comm; ucc_module->super.coll_module_enable = mca_coll_ucc_module_enable; *priority = cm->ucc_priority; - SET_COLL_PTR(ucc_module, BARRIER, barrier); - SET_COLL_PTR(ucc_module, BCAST, bcast); - SET_COLL_PTR(ucc_module, ALLREDUCE, allreduce); - SET_COLL_PTR(ucc_module, ALLTOALL, alltoall); - SET_COLL_PTR(ucc_module, ALLTOALLV, alltoallv); - SET_COLL_PTR(ucc_module, REDUCE, reduce); - SET_COLL_PTR(ucc_module, ALLGATHER, allgather); - SET_COLL_PTR(ucc_module, ALLGATHERV, allgatherv); + SET_COLL_PTR(ucc_module, BARRIER, barrier); + SET_COLL_PTR(ucc_module, BCAST, bcast); + SET_COLL_PTR(ucc_module, ALLREDUCE, allreduce); + SET_COLL_PTR(ucc_module, ALLTOALL, alltoall); + SET_COLL_PTR(ucc_module, ALLTOALLV, alltoallv); + SET_COLL_PTR(ucc_module, REDUCE, reduce); + SET_COLL_PTR(ucc_module, ALLGATHER, allgather); + SET_COLL_PTR(ucc_module, ALLGATHERV, allgatherv); + SET_COLL_PTR(ucc_module, REDUCE_SCATTER, reduce_scatter_block); return &ucc_module->super; } diff --git a/ompi/mca/coll/ucc/coll_ucc_reduce_scatter_block.c b/ompi/mca/coll/ucc/coll_ucc_reduce_scatter_block.c new file mode 100644 index 00000000000..e12f472733e --- /dev/null +++ b/ompi/mca/coll/ucc/coll_ucc_reduce_scatter_block.c @@ -0,0 +1,117 @@ +/** + * Copyright (c) 2021 Mellanox Technologies. All rights reserved. + * Copyright (c) 2022 NVIDIA Corporation. All rights reserved. + * $COPYRIGHT$ + * + * Additional copyrights may follow + * + */ + +#include "coll_ucc_common.h" + +static inline +ucc_status_t mca_coll_ucc_reduce_scatter_block_init(const void *sbuf, void *rbuf, + size_t rcount, + struct ompi_datatype_t *dtype, + struct ompi_op_t *op, + mca_coll_ucc_module_t *ucc_module, + ucc_coll_req_h *req, + mca_coll_ucc_req_t *coll_req) +{ + ucc_datatype_t ucc_dt; + ucc_reduction_op_t ucc_op; + int comm_size = ompi_comm_size(ucc_module->comm); + + if (MPI_IN_PLACE == sbuf) { + /* TODO: UCC defines inplace differently: + data in rbuf of rank R is shifted by R * rcount */ + UCC_VERBOSE(5, "inplace reduce_scatter_block is not supported"); + return UCC_ERR_NOT_SUPPORTED; + } + ucc_dt = ompi_dtype_to_ucc_dtype(dtype); + ucc_op = ompi_op_to_ucc_op(op); + if (OPAL_UNLIKELY(COLL_UCC_DT_UNSUPPORTED == ucc_dt)) { + UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s", + dtype->super.name); + goto fallback; + } + if (OPAL_UNLIKELY(COLL_UCC_OP_UNSUPPORTED == ucc_op)) { + UCC_VERBOSE(5, "ompi_op is not supported: op = %s", + op->o_name); + goto fallback; + } + ucc_coll_args_t coll = { + .mask = 0, + .flags = 0, + .coll_type = UCC_COLL_TYPE_REDUCE_SCATTER, + .src.info = { + .buffer = (void*)sbuf, + .count = ((size_t)rcount) * comm_size, + .datatype = ucc_dt, + .mem_type = UCC_MEMORY_TYPE_UNKNOWN + }, + .dst.info = { + .buffer = rbuf, + .count = rcount, + .datatype = ucc_dt, + .mem_type = UCC_MEMORY_TYPE_UNKNOWN + }, + .op = ucc_op, + }; + COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module); + return UCC_OK; +fallback: + return UCC_ERR_NOT_SUPPORTED; +} + +int mca_coll_ucc_reduce_scatter_block(const void *sbuf, void *rbuf, int rcount, + struct ompi_datatype_t *dtype, + struct ompi_op_t *op, + struct ompi_communicator_t *comm, + mca_coll_base_module_t *module) +{ + mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t*)module; + ucc_coll_req_h req; + + UCC_VERBOSE(3, "running ucc reduce scatter block"); + COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_block_init(sbuf, rbuf, rcount, + dtype, op, ucc_module, + &req, NULL)); + COLL_UCC_POST_AND_CHECK(req); + COLL_UCC_CHECK(coll_ucc_req_wait(req)); + return OMPI_SUCCESS; +fallback: + UCC_VERBOSE(3, "running fallback reduce_scatter_block"); + return ucc_module->previous_reduce_scatter_block(sbuf, rbuf, rcount, dtype, + op, comm, + ucc_module->previous_reduce_scatter_block_module); +} + +int mca_coll_ucc_ireduce_scatter_block(const void *sbuf, void *rbuf, int rcount, + struct ompi_datatype_t *dtype, + struct ompi_op_t *op, + struct ompi_communicator_t *comm, + ompi_request_t** request, + mca_coll_base_module_t *module) +{ + mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t*)module; + ucc_coll_req_h req; + mca_coll_ucc_req_t *coll_req = NULL; + + UCC_VERBOSE(3, "running ucc ireduce_scatter_block"); + COLL_UCC_GET_REQ(coll_req); + COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_block_init(sbuf, rbuf, rcount, + dtype, op, ucc_module, + &req, coll_req)); + COLL_UCC_POST_AND_CHECK(req); + *request = &coll_req->super; + return OMPI_SUCCESS; +fallback: + UCC_VERBOSE(3, "running fallback ireduce_scatter_block"); + if (coll_req) { + mca_coll_ucc_req_free((ompi_request_t **)&coll_req); + } + return ucc_module->previous_ireduce_scatter_block(sbuf, rbuf, rcount, dtype, + op, comm, request, + ucc_module->previous_ireduce_scatter_block_module); +}