Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 16 additions & 15 deletions ompi/mca/coll/ucc/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,22 @@

AM_CPPFLAGS = $(coll_ucc_CPPFLAGS)

coll_ucc_sources = \
coll_ucc.h \
coll_ucc_debug.h \
coll_ucc_dtypes.h \
coll_ucc_common.h \
coll_ucc_module.c \
coll_ucc_component.c \
coll_ucc_barrier.c \
coll_ucc_bcast.c \
coll_ucc_allreduce.c \
coll_ucc_reduce.c \
coll_ucc_alltoall.c \
coll_ucc_alltoallv.c \
coll_ucc_allgather.c \
coll_ucc_allgatherv.c
coll_ucc_sources = \
coll_ucc.h \
coll_ucc_debug.h \
coll_ucc_dtypes.h \
coll_ucc_common.h \
coll_ucc_module.c \
coll_ucc_component.c \
coll_ucc_barrier.c \
coll_ucc_bcast.c \
coll_ucc_allreduce.c \
coll_ucc_reduce.c \
coll_ucc_alltoall.c \
coll_ucc_alltoallv.c \
coll_ucc_allgather.c \
coll_ucc_allgatherv.c \
coll_ucc_reduce_scatter_block.c

# Make the output library in this directory, and name it either
# mca_<type>_<name>.la (for DSO builds) or libmca_<type>_<name>.la
Expand Down
96 changes: 57 additions & 39 deletions ompi/mca/coll/ucc/coll_ucc.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@ BEGIN_C_DECLS
#define COLL_UCC_CTS (UCC_COLL_TYPE_BARRIER | UCC_COLL_TYPE_BCAST | \
UCC_COLL_TYPE_ALLREDUCE | UCC_COLL_TYPE_ALLTOALL | \
UCC_COLL_TYPE_ALLTOALLV | UCC_COLL_TYPE_ALLGATHER | \
UCC_COLL_TYPE_REDUCE | UCC_COLL_TYPE_ALLGATHERV)
UCC_COLL_TYPE_REDUCE | UCC_COLL_TYPE_ALLGATHERV | \
UCC_COLL_TYPE_REDUCE_SCATTER)

#define COLL_UCC_CTS_STR "barrier,bcast,allreduce,alltoall,alltoallv,allgather,allgatherv,reduce," \
"ibarrier,ibcast,iallreduce,ialltoall,ialltoallv,iallgather,iallgatherv,ireduce"
#define COLL_UCC_CTS_STR "barrier,bcast,allreduce,alltoall,alltoallv,allgather,allgatherv,reduce,reduce_scatter_block," \
"ibarrier,ibcast,iallreduce,ialltoall,ialltoallv,iallgather,iallgatherv,ireduce,ireduce_scatter_block"

typedef struct mca_coll_ucc_req {
ompi_request_t super;
Expand Down Expand Up @@ -64,42 +65,46 @@ OMPI_MODULE_DECLSPEC extern mca_coll_ucc_component_t mca_coll_ucc_component;
* UCC enabled communicator
*/
struct mca_coll_ucc_module_t {
mca_coll_base_module_t super;
ompi_communicator_t* comm;
int rank;
ucc_team_h ucc_team;
mca_coll_base_module_allreduce_fn_t previous_allreduce;
mca_coll_base_module_t* previous_allreduce_module;
mca_coll_base_module_iallreduce_fn_t previous_iallreduce;
mca_coll_base_module_t* previous_iallreduce_module;
mca_coll_base_module_reduce_fn_t previous_reduce;
mca_coll_base_module_t* previous_reduce_module;
mca_coll_base_module_ireduce_fn_t previous_ireduce;
mca_coll_base_module_t* previous_ireduce_module;
mca_coll_base_module_barrier_fn_t previous_barrier;
mca_coll_base_module_t* previous_barrier_module;
mca_coll_base_module_ibarrier_fn_t previous_ibarrier;
mca_coll_base_module_t* previous_ibarrier_module;
mca_coll_base_module_bcast_fn_t previous_bcast;
mca_coll_base_module_t* previous_bcast_module;
mca_coll_base_module_ibcast_fn_t previous_ibcast;
mca_coll_base_module_t* previous_ibcast_module;
mca_coll_base_module_alltoall_fn_t previous_alltoall;
mca_coll_base_module_t* previous_alltoall_module;
mca_coll_base_module_ialltoall_fn_t previous_ialltoall;
mca_coll_base_module_t* previous_ialltoall_module;
mca_coll_base_module_alltoallv_fn_t previous_alltoallv;
mca_coll_base_module_t* previous_alltoallv_module;
mca_coll_base_module_ialltoallv_fn_t previous_ialltoallv;
mca_coll_base_module_t* previous_ialltoallv_module;
mca_coll_base_module_allgather_fn_t previous_allgather;
mca_coll_base_module_t* previous_allgather_module;
mca_coll_base_module_iallgather_fn_t previous_iallgather;
mca_coll_base_module_t* previous_iallgather_module;
mca_coll_base_module_allgatherv_fn_t previous_allgatherv;
mca_coll_base_module_t* previous_allgatherv_module;
mca_coll_base_module_iallgatherv_fn_t previous_iallgatherv;
mca_coll_base_module_t* previous_iallgatherv_module;
mca_coll_base_module_t super;
ompi_communicator_t* comm;
int rank;
ucc_team_h ucc_team;
mca_coll_base_module_allreduce_fn_t previous_allreduce;
mca_coll_base_module_t* previous_allreduce_module;
mca_coll_base_module_iallreduce_fn_t previous_iallreduce;
mca_coll_base_module_t* previous_iallreduce_module;
mca_coll_base_module_reduce_fn_t previous_reduce;
mca_coll_base_module_t* previous_reduce_module;
mca_coll_base_module_ireduce_fn_t previous_ireduce;
mca_coll_base_module_t* previous_ireduce_module;
mca_coll_base_module_barrier_fn_t previous_barrier;
mca_coll_base_module_t* previous_barrier_module;
mca_coll_base_module_ibarrier_fn_t previous_ibarrier;
mca_coll_base_module_t* previous_ibarrier_module;
mca_coll_base_module_bcast_fn_t previous_bcast;
mca_coll_base_module_t* previous_bcast_module;
mca_coll_base_module_ibcast_fn_t previous_ibcast;
mca_coll_base_module_t* previous_ibcast_module;
mca_coll_base_module_alltoall_fn_t previous_alltoall;
mca_coll_base_module_t* previous_alltoall_module;
mca_coll_base_module_ialltoall_fn_t previous_ialltoall;
mca_coll_base_module_t* previous_ialltoall_module;
mca_coll_base_module_alltoallv_fn_t previous_alltoallv;
mca_coll_base_module_t* previous_alltoallv_module;
mca_coll_base_module_ialltoallv_fn_t previous_ialltoallv;
mca_coll_base_module_t* previous_ialltoallv_module;
mca_coll_base_module_allgather_fn_t previous_allgather;
mca_coll_base_module_t* previous_allgather_module;
mca_coll_base_module_iallgather_fn_t previous_iallgather;
mca_coll_base_module_t* previous_iallgather_module;
mca_coll_base_module_allgatherv_fn_t previous_allgatherv;
mca_coll_base_module_t* previous_allgatherv_module;
mca_coll_base_module_iallgatherv_fn_t previous_iallgatherv;
mca_coll_base_module_t* previous_iallgatherv_module;
mca_coll_base_module_reduce_scatter_block_fn_t previous_reduce_scatter_block;
mca_coll_base_module_t* previous_reduce_scatter_block_module;
mca_coll_base_module_ireduce_scatter_block_fn_t previous_ireduce_scatter_block;
mca_coll_base_module_t* previous_ireduce_scatter_block_module;
};
typedef struct mca_coll_ucc_module_t mca_coll_ucc_module_t;
OBJ_CLASS_DECLARATION(mca_coll_ucc_module_t);
Expand Down Expand Up @@ -195,5 +200,18 @@ int mca_coll_ucc_iallgatherv(const void *sbuf, int scount, struct ompi_datatype_
ompi_request_t** request,
mca_coll_base_module_t *module);

int mca_coll_ucc_reduce_scatter_block(const void *sbuf, void *rbuf, int rcount,
struct ompi_datatype_t *dtype,
struct ompi_op_t *op,
struct ompi_communicator_t *comm,
mca_coll_base_module_t *module);

int mca_coll_ucc_ireduce_scatter_block(const void *sbuf, void *rbuf, int rcount,
struct ompi_datatype_t *dtype,
struct ompi_op_t *op,
struct ompi_communicator_t *comm,
ompi_request_t** request,
mca_coll_base_module_t *module);

END_C_DECLS
#endif
2 changes: 2 additions & 0 deletions ompi/mca/coll/ucc/coll_ucc_component.c
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ static ucc_coll_type_t mca_coll_ucc_str_to_type(const char *str)
return UCC_COLL_TYPE_ALLGATHERV;
} else if (0 == strcasecmp(str, "reduce")) {
return UCC_COLL_TYPE_REDUCE;
} else if (0 == strcasecmp(str, "reduce_scatter_block")) {
return UCC_COLL_TYPE_REDUCE_SCATTER;
}
UCC_ERROR("incorrect value for cts: %s, allowed: %s",
str, COLL_UCC_CTS_STR);
Expand Down
59 changes: 34 additions & 25 deletions ompi/mca/coll/ucc/coll_ucc_module.c
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,27 @@ int mca_coll_ucc_init_query(bool enable_progress_threads, bool enable_mpi_thread

/*
 * Reset the cached state of a UCC collective module: the UCC team handle and
 * every saved "previous" collective function pointer are set to NULL.
 * Only the reduce_scatter_block / ireduce_scatter_block entries also have
 * their saved previous_*_module pointers reset here; the other
 * previous_*_module pointers are not touched by this function.
 */
static void mca_coll_ucc_module_clear(mca_coll_ucc_module_t *ucc_module)
{
ucc_module->ucc_team = NULL;
ucc_module->previous_allreduce = NULL;
ucc_module->previous_iallreduce = NULL;
ucc_module->previous_barrier = NULL;
ucc_module->previous_ibarrier = NULL;
ucc_module->previous_bcast = NULL;
ucc_module->previous_ibcast = NULL;
ucc_module->previous_alltoall = NULL;
ucc_module->previous_ialltoall = NULL;
ucc_module->previous_alltoallv = NULL;
ucc_module->previous_ialltoallv = NULL;
ucc_module->previous_allgather = NULL;
ucc_module->previous_iallgather = NULL;
ucc_module->previous_allgatherv = NULL;
ucc_module->previous_iallgatherv = NULL;
ucc_module->previous_reduce = NULL;
ucc_module->previous_ireduce = NULL;
/* NOTE(review): the group below repeats the assignments above before adding
 * the reduce_scatter_block entries; this looks like both sides of a diff
 * hunk rendered together -- confirm against the actual file. */
ucc_module->ucc_team = NULL;
ucc_module->previous_allreduce = NULL;
ucc_module->previous_iallreduce = NULL;
ucc_module->previous_barrier = NULL;
ucc_module->previous_ibarrier = NULL;
ucc_module->previous_bcast = NULL;
ucc_module->previous_ibcast = NULL;
ucc_module->previous_alltoall = NULL;
ucc_module->previous_ialltoall = NULL;
ucc_module->previous_alltoallv = NULL;
ucc_module->previous_ialltoallv = NULL;
ucc_module->previous_allgather = NULL;
ucc_module->previous_iallgather = NULL;
ucc_module->previous_allgatherv = NULL;
ucc_module->previous_iallgatherv = NULL;
ucc_module->previous_reduce = NULL;
ucc_module->previous_ireduce = NULL;
/* Blocking and non-blocking reduce_scatter_block: function pointer and the
 * module it belongs to are both reset. */
ucc_module->previous_reduce_scatter_block = NULL;
ucc_module->previous_reduce_scatter_block_module = NULL;
ucc_module->previous_ireduce_scatter_block = NULL;
ucc_module->previous_ireduce_scatter_block_module = NULL;
}

static void mca_coll_ucc_module_construct(mca_coll_ucc_module_t *ucc_module)
Expand Down Expand Up @@ -82,6 +86,8 @@ static void mca_coll_ucc_module_destruct(mca_coll_ucc_module_t *ucc_module)
OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_iallgatherv_module);
OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_reduce_module);
OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_ireduce_module);
OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_reduce_scatter_block_module);
OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_ireduce_scatter_block_module);
mca_coll_ucc_module_clear(ucc_module);
}

Expand Down Expand Up @@ -113,6 +119,8 @@ static int mca_coll_ucc_save_coll_handlers(mca_coll_ucc_module_t *ucc_module)
SAVE_PREV_COLL_API(iallgatherv);
SAVE_PREV_COLL_API(reduce);
SAVE_PREV_COLL_API(ireduce);
SAVE_PREV_COLL_API(reduce_scatter_block);
SAVE_PREV_COLL_API(ireduce_scatter_block);
return OMPI_SUCCESS;
}

Expand Down Expand Up @@ -491,14 +499,15 @@ mca_coll_ucc_comm_query(struct ompi_communicator_t *comm, int *priority)
ucc_module->comm = comm;
ucc_module->super.coll_module_enable = mca_coll_ucc_module_enable;
*priority = cm->ucc_priority;
SET_COLL_PTR(ucc_module, BARRIER, barrier);
SET_COLL_PTR(ucc_module, BCAST, bcast);
SET_COLL_PTR(ucc_module, ALLREDUCE, allreduce);
SET_COLL_PTR(ucc_module, ALLTOALL, alltoall);
SET_COLL_PTR(ucc_module, ALLTOALLV, alltoallv);
SET_COLL_PTR(ucc_module, REDUCE, reduce);
SET_COLL_PTR(ucc_module, ALLGATHER, allgather);
SET_COLL_PTR(ucc_module, ALLGATHERV, allgatherv);
SET_COLL_PTR(ucc_module, BARRIER, barrier);
SET_COLL_PTR(ucc_module, BCAST, bcast);
SET_COLL_PTR(ucc_module, ALLREDUCE, allreduce);
SET_COLL_PTR(ucc_module, ALLTOALL, alltoall);
SET_COLL_PTR(ucc_module, ALLTOALLV, alltoallv);
SET_COLL_PTR(ucc_module, REDUCE, reduce);
SET_COLL_PTR(ucc_module, ALLGATHER, allgather);
SET_COLL_PTR(ucc_module, ALLGATHERV, allgatherv);
SET_COLL_PTR(ucc_module, REDUCE_SCATTER, reduce_scatter_block);
return &ucc_module->super;
}

Expand Down
117 changes: 117 additions & 0 deletions ompi/mca/coll/ucc/coll_ucc_reduce_scatter_block.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/**
* Copyright (c) 2021 Mellanox Technologies. All rights reserved.
* Copyright (c) 2022 NVIDIA Corporation. All rights reserved.
* $COPYRIGHT$
*
* Additional copyrights may follow
*
*/

#include "coll_ucc_common.h"

/**
 * Fill in the UCC collective arguments for a reduce_scatter_block operation
 * and hand them to COLL_UCC_REQ_INIT.
 *
 * The source side is described as rcount * (communicator size) elements and
 * the destination side as rcount elements, both with the translated UCC
 * datatype and an unknown memory type.
 *
 * @param sbuf        send buffer; MPI_IN_PLACE is rejected
 * @param rbuf        receive buffer
 * @param rcount      number of elements placed in rbuf
 * @param dtype       OMPI datatype, translated with ompi_dtype_to_ucc_dtype()
 * @param op          OMPI reduction op, translated with ompi_op_to_ucc_op()
 * @param ucc_module  module whose communicator provides the size; also
 *                    forwarded to COLL_UCC_REQ_INIT
 * @param req         UCC request handle forwarded to COLL_UCC_REQ_INIT
 * @param coll_req    OMPI-side request forwarded to COLL_UCC_REQ_INIT
 *
 * @return UCC_OK once COLL_UCC_REQ_INIT has run; UCC_ERR_NOT_SUPPORTED for
 *         MPI_IN_PLACE, an unsupported datatype or an unsupported op, and
 *         whenever control reaches the fallback label.
 */
static inline
ucc_status_t mca_coll_ucc_reduce_scatter_block_init(const void *sbuf, void *rbuf,
                                                    size_t rcount,
                                                    struct ompi_datatype_t *dtype,
                                                    struct ompi_op_t *op,
                                                    mca_coll_ucc_module_t *ucc_module,
                                                    ucc_coll_req_h *req,
                                                    mca_coll_ucc_req_t *coll_req)
{
    int                nranks = ompi_comm_size(ucc_module->comm);
    ucc_datatype_t     dt;
    ucc_reduction_op_t redop;

    if (MPI_IN_PLACE == sbuf) {
        /* TODO: UCC defines inplace differently:
           data in rbuf of rank R is shifted by R * rcount */
        UCC_VERBOSE(5, "inplace reduce_scatter_block is not supported");
        goto fallback;
    }

    /* Translate both handles first, then validate each translation. */
    dt    = ompi_dtype_to_ucc_dtype(dtype);
    redop = ompi_op_to_ucc_op(op);

    if (OPAL_UNLIKELY(COLL_UCC_DT_UNSUPPORTED == dt)) {
        UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s",
                    dtype->super.name);
        goto fallback;
    }
    if (OPAL_UNLIKELY(COLL_UCC_OP_UNSUPPORTED == redop)) {
        UCC_VERBOSE(5, "ompi_op is not supported: op = %s",
                    op->o_name);
        goto fallback;
    }

    ucc_coll_args_t args = {
        .mask              = 0,
        .flags             = 0,
        .coll_type         = UCC_COLL_TYPE_REDUCE_SCATTER,
        /* the source holds one rcount-sized block per rank */
        .src.info.buffer   = (void*)sbuf,
        .src.info.count    = rcount * (size_t)nranks,
        .src.info.datatype = dt,
        .src.info.mem_type = UCC_MEMORY_TYPE_UNKNOWN,
        .dst.info.buffer   = rbuf,
        .dst.info.count    = rcount,
        .dst.info.datatype = dt,
        .dst.info.mem_type = UCC_MEMORY_TYPE_UNKNOWN,
        .op                = redop,
    };

    /* NOTE(review): COLL_UCC_REQ_INIT is defined outside this file; it
     * presumably jumps to the fallback label on failure -- confirm. */
    COLL_UCC_REQ_INIT(coll_req, req, args, ucc_module);
    return UCC_OK;

fallback:
    return UCC_ERR_NOT_SUPPORTED;
}

/**
 * Blocking reduce_scatter_block entry point of the UCC collective module.
 *
 * Sets up the UCC request with mca_coll_ucc_reduce_scatter_block_init()
 * (passing NULL as the OMPI-side request), posts it and waits on it with
 * coll_ucc_req_wait(). The code under the fallback label forwards the call
 * unchanged to the previously saved reduce_scatter_block implementation.
 *
 * @param sbuf    send buffer
 * @param rbuf    receive buffer
 * @param rcount  number of elements placed in rbuf
 * @param dtype   element datatype
 * @param op      reduction operation
 * @param comm    communicator (only used by the fallback call)
 * @param module  collective module; cast to mca_coll_ucc_module_t
 *
 * @return OMPI_SUCCESS when the UCC path runs to completion, otherwise the
 *         return value of the saved previous implementation.
 */
int mca_coll_ucc_reduce_scatter_block(const void *sbuf, void *rbuf, int rcount,
                                      struct ompi_datatype_t *dtype,
                                      struct ompi_op_t *op,
                                      struct ompi_communicator_t *comm,
                                      mca_coll_base_module_t *module)
{
    ucc_coll_req_h         handle;
    mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t*)module;

    UCC_VERBOSE(3, "running ucc reduce scatter block");

    /* init -> post -> wait */
    COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_block_init(sbuf, rbuf, rcount,
                                                          dtype, op, ucc_module,
                                                          &handle, NULL));
    COLL_UCC_POST_AND_CHECK(handle);
    COLL_UCC_CHECK(coll_ucc_req_wait(handle));
    return OMPI_SUCCESS;

fallback:
    UCC_VERBOSE(3, "running fallback reduce_scatter_block");
    mca_coll_base_module_t *prev_module =
        ucc_module->previous_reduce_scatter_block_module;
    return ucc_module->previous_reduce_scatter_block(sbuf, rbuf, rcount, dtype,
                                                     op, comm, prev_module);
}

/**
 * Non-blocking reduce_scatter_block entry point of the UCC collective module.
 *
 * Obtains an OMPI-side request with COLL_UCC_GET_REQ, sets up the UCC request
 * with mca_coll_ucc_reduce_scatter_block_init(), posts it and returns the
 * OMPI request through *request. The code under the fallback label frees the
 * OMPI-side request if one was obtained and forwards the call unchanged to
 * the previously saved ireduce_scatter_block implementation.
 *
 * @param sbuf     send buffer
 * @param rbuf     receive buffer
 * @param rcount   number of elements placed in rbuf
 * @param dtype    element datatype
 * @param op       reduction operation
 * @param comm     communicator (only used by the fallback call)
 * @param request  out: request the caller completes later
 * @param module   collective module; cast to mca_coll_ucc_module_t
 *
 * @return OMPI_SUCCESS when the UCC request was posted, otherwise the return
 *         value of the saved previous implementation.
 */
int mca_coll_ucc_ireduce_scatter_block(const void *sbuf, void *rbuf, int rcount,
                                       struct ompi_datatype_t *dtype,
                                       struct ompi_op_t *op,
                                       struct ompi_communicator_t *comm,
                                       ompi_request_t** request,
                                       mca_coll_base_module_t *module)
{
    mca_coll_ucc_req_t    *coll_req   = NULL;
    ucc_coll_req_h         handle;
    mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t*)module;

    UCC_VERBOSE(3, "running ucc ireduce_scatter_block");

    COLL_UCC_GET_REQ(coll_req);
    COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_block_init(sbuf, rbuf, rcount,
                                                          dtype, op, ucc_module,
                                                          &handle, coll_req));
    COLL_UCC_POST_AND_CHECK(handle);

    /* hand the embedded OMPI request back to the caller */
    *request = &coll_req->super;
    return OMPI_SUCCESS;

fallback:
    UCC_VERBOSE(3, "running fallback ireduce_scatter_block");
    /* coll_req stays NULL if the fallback was taken before a request was obtained */
    if (NULL != coll_req) {
        mca_coll_ucc_req_free((ompi_request_t **)&coll_req);
    }
    mca_coll_base_module_t *prev_module =
        ucc_module->previous_ireduce_scatter_block_module;
    return ucc_module->previous_ireduce_scatter_block(sbuf, rbuf, rcount, dtype,
                                                      op, comm, request,
                                                      prev_module);
}
Loading