diff --git a/ompi/mca/coll/accelerator/Makefile.am b/ompi/mca/coll/accelerator/Makefile.am index eaf81137602..e3621c1d05a 100644 --- a/ompi/mca/coll/accelerator/Makefile.am +++ b/ompi/mca/coll/accelerator/Makefile.am @@ -2,7 +2,7 @@ # Copyright (c) 2014 The University of Tennessee and The University # of Tennessee Research Foundation. All rights # reserved. -# Copyright (c) 2014 NVIDIA Corporation. All rights reserved. +# Copyright (c) 2014-2024 NVIDIA Corporation. All rights reserved. # Copyright (c) 2017 IBM Corporation. All rights reserved. # $COPYRIGHT$ # @@ -10,7 +10,6 @@ # # $HEADER$ # -dist_ompidata_DATA = help-mpi-coll-accelerator.txt sources = coll_accelerator_module.c coll_accelerator_reduce.c coll_accelerator_allreduce.c \ coll_accelerator_reduce_scatter_block.c coll_accelerator_component.c \ diff --git a/ompi/mca/coll/accelerator/coll_accelerator.h b/ompi/mca/coll/accelerator/coll_accelerator.h index c840a3c2d27..5e8eb64794f 100644 --- a/ompi/mca/coll/accelerator/coll_accelerator.h +++ b/ompi/mca/coll/accelerator/coll_accelerator.h @@ -2,7 +2,7 @@ * Copyright (c) 2014 The University of Tennessee and The University * of Tennessee Research Foundation. All rights * reserved. - * Copyright (c) 2014-2015 NVIDIA Corporation. All rights reserved. + * Copyright (c) 2014-2024 NVIDIA Corporation. All rights reserved. * Copyright (c) 2024 Triad National Security, LLC. All rights reserved. 
* $COPYRIGHT$ * @@ -38,9 +38,6 @@ mca_coll_base_module_t *mca_coll_accelerator_comm_query(struct ompi_communicator_t *comm, int *priority); -int mca_coll_accelerator_module_enable(mca_coll_base_module_t *module, - struct ompi_communicator_t *comm); - int mca_coll_accelerator_allreduce(const void *sbuf, void *rbuf, int count, struct ompi_datatype_t *dtype, diff --git a/ompi/mca/coll/accelerator/coll_accelerator_module.c b/ompi/mca/coll/accelerator/coll_accelerator_module.c index 505d31c1c07..4fe1603a8aa 100644 --- a/ompi/mca/coll/accelerator/coll_accelerator_module.c +++ b/ompi/mca/coll/accelerator/coll_accelerator_module.c @@ -2,7 +2,7 @@ * Copyright (c) 2014-2017 The University of Tennessee and The University * of Tennessee Research Foundation. All rights * reserved. - * Copyright (c) 2014 NVIDIA Corporation. All rights reserved. + * Copyright (c) 2014-2024 NVIDIA Corporation. All rights reserved. * Copyright (c) 2019 Research Organization for Information Science * and Technology (RIST). All rights reserved. * Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. 
@@ -32,30 +32,21 @@ #include "ompi/mca/coll/base/base.h" #include "coll_accelerator.h" +static int +mca_coll_accelerator_module_enable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm); +static int +mca_coll_accelerator_module_disable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm); static void mca_coll_accelerator_module_construct(mca_coll_accelerator_module_t *module) { memset(&(module->c_coll), 0, sizeof(module->c_coll)); } -static void mca_coll_accelerator_module_destruct(mca_coll_accelerator_module_t *module) -{ - OBJ_RELEASE(module->c_coll.coll_allreduce_module); - OBJ_RELEASE(module->c_coll.coll_reduce_module); - OBJ_RELEASE(module->c_coll.coll_reduce_scatter_block_module); - OBJ_RELEASE(module->c_coll.coll_scatter_module); - /* If the exscan module is not NULL, then this was an - intracommunicator, and therefore scan will have a module as - well. */ - if (NULL != module->c_coll.coll_exscan_module) { - OBJ_RELEASE(module->c_coll.coll_exscan_module); - OBJ_RELEASE(module->c_coll.coll_scan_module); - } -} - OBJ_CLASS_INSTANCE(mca_coll_accelerator_module_t, mca_coll_base_module_t, mca_coll_accelerator_module_construct, - mca_coll_accelerator_module_destruct); + NULL); /* @@ -99,66 +90,82 @@ mca_coll_accelerator_comm_query(struct ompi_communicator_t *comm, /* Choose whether to use [intra|inter] */ accelerator_module->super.coll_module_enable = mca_coll_accelerator_module_enable; + accelerator_module->super.coll_module_disable = mca_coll_accelerator_module_disable; - accelerator_module->super.coll_allgather = NULL; - accelerator_module->super.coll_allgatherv = NULL; accelerator_module->super.coll_allreduce = mca_coll_accelerator_allreduce; - accelerator_module->super.coll_alltoall = NULL; - accelerator_module->super.coll_alltoallv = NULL; - accelerator_module->super.coll_alltoallw = NULL; - accelerator_module->super.coll_barrier = NULL; - accelerator_module->super.coll_bcast = NULL; - accelerator_module->super.coll_exscan = 
mca_coll_accelerator_exscan; - accelerator_module->super.coll_gather = NULL; - accelerator_module->super.coll_gatherv = NULL; accelerator_module->super.coll_reduce = mca_coll_accelerator_reduce; - accelerator_module->super.coll_reduce_scatter = NULL; accelerator_module->super.coll_reduce_scatter_block = mca_coll_accelerator_reduce_scatter_block; - accelerator_module->super.coll_scan = mca_coll_accelerator_scan; - accelerator_module->super.coll_scatter = NULL; - accelerator_module->super.coll_scatterv = NULL; + if (!OMPI_COMM_IS_INTER(comm)) { + accelerator_module->super.coll_scan = mca_coll_accelerator_scan; + accelerator_module->super.coll_exscan = mca_coll_accelerator_exscan; + } return &(accelerator_module->super); } +#define ACCELERATOR_INSTALL_COLL_API(__comm, __module, __api) \ + do \ + { \ + if ((__comm)->c_coll->coll_##__api) \ + { \ + MCA_COLL_SAVE_API(__comm, __api, (__module)->c_coll.coll_##__api, (__module)->c_coll.coll_##__api##_module, "accelerator"); \ + MCA_COLL_INSTALL_API(__comm, __api, mca_coll_accelerator_##__api, &__module->super, "accelerator"); \ + } \ + else \ + { \ + opal_show_help("help-mca-coll-base.txt", "comm-select:missing collective", true, \ + "cuda", #__api, ompi_process_info.nodename, \ + mca_coll_accelerator_component.priority); \ + } \ + } while (0) + +#define ACCELERATOR_UNINSTALL_COLL_API(__comm, __module, __api) \ + do \ + { \ + if (&(__module)->super == (__comm)->c_coll->coll_##__api##_module) { \ + MCA_COLL_INSTALL_API(__comm, __api, (__module)->c_coll.coll_##__api, (__module)->c_coll.coll_##__api##_module, "accelerator"); \ + (__module)->c_coll.coll_##__api##_module = NULL; \ + (__module)->c_coll.coll_##__api = NULL; \ + } \ + } while (0) + /* - * Init module on the communicator + * Init/Fini module on the communicator */ -int mca_coll_accelerator_module_enable(mca_coll_base_module_t *module, - struct ompi_communicator_t *comm) +static int +mca_coll_accelerator_module_enable(mca_coll_base_module_t *module, + struct 
ompi_communicator_t *comm) { - bool good = true; - char *msg = NULL; mca_coll_accelerator_module_t *s = (mca_coll_accelerator_module_t*) module; -#define CHECK_AND_RETAIN(src, dst, name) \ - if (NULL == (src)->c_coll->coll_ ## name ## _module) { \ - good = false; \ - msg = #name; \ - } else if (good) { \ - (dst)->c_coll.coll_ ## name ## _module = (src)->c_coll->coll_ ## name ## _module; \ - (dst)->c_coll.coll_ ## name = (src)->c_coll->coll_ ## name; \ - OBJ_RETAIN((src)->c_coll->coll_ ## name ## _module); \ - } - - CHECK_AND_RETAIN(comm, s, allreduce); - CHECK_AND_RETAIN(comm, s, reduce); - CHECK_AND_RETAIN(comm, s, reduce_scatter_block); - CHECK_AND_RETAIN(comm, s, scatter); + ACCELERATOR_INSTALL_COLL_API(comm, s, allreduce); + ACCELERATOR_INSTALL_COLL_API(comm, s, reduce); + ACCELERATOR_INSTALL_COLL_API(comm, s, reduce_scatter_block); if (!OMPI_COMM_IS_INTER(comm)) { /* MPI does not define scan/exscan on intercommunicators */ - CHECK_AND_RETAIN(comm, s, exscan); - CHECK_AND_RETAIN(comm, s, scan); + ACCELERATOR_INSTALL_COLL_API(comm, s, exscan); + ACCELERATOR_INSTALL_COLL_API(comm, s, scan); } - /* All done */ - if (good) { - return OMPI_SUCCESS; - } - opal_show_help("help-mpi-coll-accelerator.txt", "missing collective", true, - ompi_process_info.nodename, - mca_coll_accelerator_component.priority, msg); - return OMPI_ERR_NOT_FOUND; + return OMPI_SUCCESS; } +static int +mca_coll_accelerator_module_disable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm) +{ + mca_coll_accelerator_module_t *s = (mca_coll_accelerator_module_t*) module; + + ACCELERATOR_UNINSTALL_COLL_API(comm, s, allreduce); + ACCELERATOR_UNINSTALL_COLL_API(comm, s, reduce); + ACCELERATOR_UNINSTALL_COLL_API(comm, s, reduce_scatter_block); + if (!OMPI_COMM_IS_INTER(comm)) + { + /* MPI does not define scan/exscan on intercommunicators */ + ACCELERATOR_UNINSTALL_COLL_API(comm, s, exscan); + ACCELERATOR_UNINSTALL_COLL_API(comm, s, scan); + } + + return OMPI_SUCCESS; +} diff --git 
a/ompi/mca/coll/accelerator/help-mpi-coll-accelerator.txt b/ompi/mca/coll/accelerator/help-mpi-coll-accelerator.txt deleted file mode 100644 index abc39d08e12..00000000000 --- a/ompi/mca/coll/accelerator/help-mpi-coll-accelerator.txt +++ /dev/null @@ -1,29 +0,0 @@ -# -*- text -*- -# -# Copyright (c) 2004-2005 The Trustees of Indiana University and Indiana -# University Research and Technology -# Corporation. All rights reserved. -# Copyright (c) 2004-2005 The University of Tennessee and The University -# of Tennessee Research Foundation. All rights -# reserved. -# Copyright (c) 2004-2005 High Performance Computing Center Stuttgart, -# University of Stuttgart. All rights reserved. -# Copyright (c) 2004-2005 The Regents of the University of California. -# All rights reserved. -# Copyright (c) 2014 NVIDIA Corporation. All rights reserved. -# Copyright (c) 2024 Triad National Security, LLC. All rights reserved. -# $COPYRIGHT$ -# -# Additional copyrights may follow -# -# $HEADER$ -# -# This is the US/English general help file for Open MPI's accelerator -# collective reduction component. -# -[missing collective] -There was a problem while initializing support for the accelerator reduction operations. -hostname: %s -priority: %d -collective: %s -# diff --git a/ompi/mca/coll/adapt/coll_adapt_module.c b/ompi/mca/coll/adapt/coll_adapt_module.c index 12e79d72d4a..abd27d09870 100644 --- a/ompi/mca/coll/adapt/coll_adapt_module.c +++ b/ompi/mca/coll/adapt/coll_adapt_module.c @@ -5,6 +5,7 @@ * Copyright (c) 2021 Triad National Security, LLC. All rights * reserved. * Copyright (c) 2022 IBM Corporation. All rights reserved + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * * $COPYRIGHT$ * @@ -83,25 +84,41 @@ OBJ_CLASS_INSTANCE(mca_coll_adapt_module_t, adapt_module_construct, adapt_module_destruct); -/* - * In this macro, the following variables are supposed to have been declared - * in the caller: - * . ompi_communicator_t *comm - * . 
mca_coll_adapt_module_t *adapt_module - */ -#define ADAPT_SAVE_PREV_COLL_API(__api) \ - do { \ - adapt_module->previous_ ## __api = comm->c_coll->coll_ ## __api; \ - adapt_module->previous_ ## __api ## _module = comm->c_coll->coll_ ## __api ## _module; \ - if (!comm->c_coll->coll_ ## __api || !comm->c_coll->coll_ ## __api ## _module) { \ - opal_output_verbose(1, ompi_coll_base_framework.framework_output, \ - "(%s/%s): no underlying " # __api"; disqualifying myself", \ - ompi_comm_print_cid(comm), comm->c_name); \ - return OMPI_ERROR; \ - } \ - OBJ_RETAIN(adapt_module->previous_ ## __api ## _module); \ - } while(0) - +#define ADAPT_INSTALL_COLL_API(__comm, __module, __api) \ + do \ + { \ + if (__module->super.coll_##__api) \ + { \ + MCA_COLL_INSTALL_API(__comm, __api, __module->super.coll_##__api, &__module->super, "adapt"); \ + } \ + } while (0) +#define ADAPT_UNINSTALL_COLL_API(__comm, __module, __api) \ + do \ + { \ + if (__comm->c_coll->coll_##__api##_module == &__module->super) \ + { \ + MCA_COLL_INSTALL_API(__comm, __api, NULL, NULL, "adapt"); \ + } \ + } while (0) +#define ADAPT_INSTALL_AND_SAVE_COLL_API(__comm, __module, __api) \ + do \ + { \ + if (__comm->c_coll->coll_##__api && __comm->c_coll->coll_##__api##_module) \ + { \ + MCA_COLL_SAVE_API(__comm, __api, __module->previous_##__api, __module->previous_##__api##_module, "adapt"); \ + MCA_COLL_INSTALL_API(__comm, __api, __module->super.coll_##__api, &__module->super, "adapt"); \ + } \ + } while (0) +#define ADAPT_UNINSTALL_AND_RESTORE_COLL_API(__comm, __module, __api) \ + do \ + { \ + if (__comm->c_coll->coll_##__api##_module == &__module->super) \ + { \ + MCA_COLL_INSTALL_API(__comm, __api, __module->previous_##__api, __module->previous_##__api##_module, "adapt"); \ + __module->previous_##__api = NULL; \ + __module->previous_##__api##_module = NULL; \ + } \ + } while (0) /* * Init module on the communicator @@ -111,12 +128,25 @@ static int adapt_module_enable(mca_coll_base_module_t * module, { 
mca_coll_adapt_module_t * adapt_module = (mca_coll_adapt_module_t*) module; - ADAPT_SAVE_PREV_COLL_API(reduce); - ADAPT_SAVE_PREV_COLL_API(ireduce); + ADAPT_INSTALL_AND_SAVE_COLL_API(comm, adapt_module, reduce); + ADAPT_INSTALL_COLL_API(comm, adapt_module, bcast); + ADAPT_INSTALL_AND_SAVE_COLL_API(comm, adapt_module, ireduce); + ADAPT_INSTALL_COLL_API(comm, adapt_module, ibcast); return OMPI_SUCCESS; } +static int adapt_module_disable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm) +{ + mca_coll_adapt_module_t *adapt_module = (mca_coll_adapt_module_t *)module; + ADAPT_UNINSTALL_AND_RESTORE_COLL_API(comm, adapt_module, reduce); + ADAPT_UNINSTALL_COLL_API(comm, adapt_module, bcast); + ADAPT_UNINSTALL_AND_RESTORE_COLL_API(comm, adapt_module, ireduce); + ADAPT_UNINSTALL_COLL_API(comm, adapt_module, ibcast); + + return OMPI_SUCCESS; +} /* * Initial query function that is invoked during MPI_INIT, allowing * this component to disqualify itself if it doesn't support the @@ -165,24 +195,11 @@ mca_coll_base_module_t *ompi_coll_adapt_comm_query(struct ompi_communicator_t * /* All is good -- return a module */ adapt_module->super.coll_module_enable = adapt_module_enable; - adapt_module->super.coll_allgather = NULL; - adapt_module->super.coll_allgatherv = NULL; - adapt_module->super.coll_allreduce = NULL; - adapt_module->super.coll_alltoall = NULL; - adapt_module->super.coll_alltoallw = NULL; - adapt_module->super.coll_barrier = NULL; + adapt_module->super.coll_module_disable = adapt_module_disable; adapt_module->super.coll_bcast = ompi_coll_adapt_bcast; - adapt_module->super.coll_exscan = NULL; - adapt_module->super.coll_gather = NULL; - adapt_module->super.coll_gatherv = NULL; adapt_module->super.coll_reduce = ompi_coll_adapt_reduce; - adapt_module->super.coll_reduce_scatter = NULL; - adapt_module->super.coll_scan = NULL; - adapt_module->super.coll_scatter = NULL; - adapt_module->super.coll_scatterv = NULL; adapt_module->super.coll_ibcast = 
ompi_coll_adapt_ibcast; adapt_module->super.coll_ireduce = ompi_coll_adapt_ireduce; - adapt_module->super.coll_iallreduce = NULL; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "coll:adapt:comm_query (%s/%s): pick me! pick me!", diff --git a/ompi/mca/coll/base/coll_base_comm_select.c b/ompi/mca/coll/base/coll_base_comm_select.c index 83175aedc7d..37f7b53d84a 100644 --- a/ompi/mca/coll/base/coll_base_comm_select.c +++ b/ompi/mca/coll/base/coll_base_comm_select.c @@ -22,6 +22,7 @@ * Copyright (c) 2016-2017 IBM Corporation. All rights reserved. * Copyright (c) 2017 FUJITSU LIMITED. All rights reserved. * Copyright (c) 2020 BULL S.A.S. All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -71,18 +72,6 @@ static int query_2_4_0(const mca_coll_base_component_2_4_0_t * int *priority, mca_coll_base_module_t ** module); -#define COPY(module, comm, func) \ - do { \ - if (NULL != module->coll_ ## func) { \ - if (NULL != comm->c_coll->coll_ ## func ## _module) { \ - OBJ_RELEASE(comm->c_coll->coll_ ## func ## _module); \ - } \ - comm->c_coll->coll_ ## func = module->coll_ ## func; \ - comm->c_coll->coll_ ## func ## _module = module; \ - OBJ_RETAIN(module); \ - } \ - } while (0) - #define CHECK_NULL(what, comm, func) \ ( (what) = # func , NULL == (comm)->c_coll->coll_ ## func) @@ -147,88 +136,6 @@ int mca_coll_base_comm_select(ompi_communicator_t * comm) /* Save every component that is initialized, * queried and enabled successfully */ opal_list_append(comm->c_coll->module_list, &avail->super); - - /* copy over any of the pointers */ - COPY(avail->ac_module, comm, allgather); - COPY(avail->ac_module, comm, allgatherv); - COPY(avail->ac_module, comm, allreduce); - COPY(avail->ac_module, comm, alltoall); - COPY(avail->ac_module, comm, alltoallv); - COPY(avail->ac_module, comm, alltoallw); - COPY(avail->ac_module, comm, barrier); - COPY(avail->ac_module, comm, bcast); - 
COPY(avail->ac_module, comm, exscan); - COPY(avail->ac_module, comm, gather); - COPY(avail->ac_module, comm, gatherv); - COPY(avail->ac_module, comm, reduce); - COPY(avail->ac_module, comm, reduce_scatter_block); - COPY(avail->ac_module, comm, reduce_scatter); - COPY(avail->ac_module, comm, scan); - COPY(avail->ac_module, comm, scatter); - COPY(avail->ac_module, comm, scatterv); - - COPY(avail->ac_module, comm, iallgather); - COPY(avail->ac_module, comm, iallgatherv); - COPY(avail->ac_module, comm, iallreduce); - COPY(avail->ac_module, comm, ialltoall); - COPY(avail->ac_module, comm, ialltoallv); - COPY(avail->ac_module, comm, ialltoallw); - COPY(avail->ac_module, comm, ibarrier); - COPY(avail->ac_module, comm, ibcast); - COPY(avail->ac_module, comm, iexscan); - COPY(avail->ac_module, comm, igather); - COPY(avail->ac_module, comm, igatherv); - COPY(avail->ac_module, comm, ireduce); - COPY(avail->ac_module, comm, ireduce_scatter_block); - COPY(avail->ac_module, comm, ireduce_scatter); - COPY(avail->ac_module, comm, iscan); - COPY(avail->ac_module, comm, iscatter); - COPY(avail->ac_module, comm, iscatterv); - - COPY(avail->ac_module, comm, allgather_init); - COPY(avail->ac_module, comm, allgatherv_init); - COPY(avail->ac_module, comm, allreduce_init); - COPY(avail->ac_module, comm, alltoall_init); - COPY(avail->ac_module, comm, alltoallv_init); - COPY(avail->ac_module, comm, alltoallw_init); - COPY(avail->ac_module, comm, barrier_init); - COPY(avail->ac_module, comm, bcast_init); - COPY(avail->ac_module, comm, exscan_init); - COPY(avail->ac_module, comm, gather_init); - COPY(avail->ac_module, comm, gatherv_init); - COPY(avail->ac_module, comm, reduce_init); - COPY(avail->ac_module, comm, reduce_scatter_block_init); - COPY(avail->ac_module, comm, reduce_scatter_init); - COPY(avail->ac_module, comm, scan_init); - COPY(avail->ac_module, comm, scatter_init); - COPY(avail->ac_module, comm, scatterv_init); - - /* We can not reliably check if this comm has a topology - * at 
this time. The flags are set *after* coll_select */ - COPY(avail->ac_module, comm, neighbor_allgather); - COPY(avail->ac_module, comm, neighbor_allgatherv); - COPY(avail->ac_module, comm, neighbor_alltoall); - COPY(avail->ac_module, comm, neighbor_alltoallv); - COPY(avail->ac_module, comm, neighbor_alltoallw); - - COPY(avail->ac_module, comm, ineighbor_allgather); - COPY(avail->ac_module, comm, ineighbor_allgatherv); - COPY(avail->ac_module, comm, ineighbor_alltoall); - COPY(avail->ac_module, comm, ineighbor_alltoallv); - COPY(avail->ac_module, comm, ineighbor_alltoallw); - - COPY(avail->ac_module, comm, neighbor_allgather_init); - COPY(avail->ac_module, comm, neighbor_allgatherv_init); - COPY(avail->ac_module, comm, neighbor_alltoall_init); - COPY(avail->ac_module, comm, neighbor_alltoallv_init); - COPY(avail->ac_module, comm, neighbor_alltoallw_init); - - COPY(avail->ac_module, comm, reduce_local); - -#if OPAL_ENABLE_FT_MPI - COPY(avail->ac_module, comm, agree); - COPY(avail->ac_module, comm, iagree); -#endif } else { /* release the original module reference and the list item */ OBJ_RELEASE(avail->ac_module); @@ -291,6 +198,10 @@ int mca_coll_base_comm_select(ompi_communicator_t * comm) ((OMPI_COMM_IS_INTRA(comm)) && CHECK_NULL(which_func, comm, scan_init)) || CHECK_NULL(which_func, comm, scatter_init) || CHECK_NULL(which_func, comm, scatterv_init) || +#if OPAL_ENABLE_FT_MPI + CHECK_NULL(which_func, comm, agree) || + CHECK_NULL(which_func, comm, iagree) || +#endif /* OPAL_ENABLE_FT_MPI */ CHECK_NULL(which_func, comm, reduce_local) ) { /* TODO -- Once the topology flags are set before coll_select then * check if neighborhood collectives have been set. 
*/ @@ -298,7 +209,7 @@ int mca_coll_base_comm_select(ompi_communicator_t * comm) opal_show_help("help-mca-coll-base.txt", "comm-select:no-function-available", true, which_func); - mca_coll_base_comm_unselect(comm); + mca_coll_base_comm_unselect(comm); return OMPI_ERR_NOT_FOUND; } return OMPI_SUCCESS; diff --git a/ompi/mca/coll/base/coll_base_comm_unselect.c b/ompi/mca/coll/base/coll_base_comm_unselect.c index 9608b498cba..aca7bb23095 100644 --- a/ompi/mca/coll/base/coll_base_comm_unselect.c +++ b/ompi/mca/coll/base/coll_base_comm_unselect.c @@ -17,6 +17,7 @@ * Copyright (c) 2017 IBM Corporation. All rights reserved. * Copyright (c) 2017 FUJITSU LIMITED. All rights reserved. * Copyright (c) 2020 BULL S.A.S. All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -37,113 +38,120 @@ #include "ompi/mca/coll/base/base.h" #include "ompi/mca/coll/base/coll_base_util.h" -#define CLOSE(comm, func) \ +#if OPAL_ENABLE_DEBUG +#define CHECK_CLEAN_COLL(comm, func) \ do { \ - if (NULL != comm->c_coll->coll_ ## func ## _module) { \ - if (NULL != comm->c_coll->coll_ ## func ## _module->coll_module_disable) { \ - comm->c_coll->coll_ ## func ## _module->coll_module_disable( \ - comm->c_coll->coll_ ## func ## _module, comm); \ - } \ - OBJ_RELEASE(comm->c_coll->coll_ ## func ## _module); \ - comm->c_coll->coll_## func = NULL; \ - comm->c_coll->coll_## func ## _module = NULL; \ + if (NULL != comm->c_coll->coll_ ## func ## _module || \ + NULL != comm->c_coll->coll_ ## func ) { \ + opal_output_verbose(10, ompi_coll_base_framework.framework_output, \ + "coll:base:comm_unselect: Comm %p (%s) has a left over %s collective during cleanup", \ + (void*)comm, comm->c_name, #func); \ } \ } while (0) +#else +#define CHECK_CLEAN_COLL(comm, func) +#endif /* OPAL_ENABLE_DEBUG */ int mca_coll_base_comm_unselect(ompi_communicator_t * comm) { opal_list_item_t *item; - CLOSE(comm, allgather); - CLOSE(comm, allgatherv); - 
CLOSE(comm, allreduce); - CLOSE(comm, alltoall); - CLOSE(comm, alltoallv); - CLOSE(comm, alltoallw); - CLOSE(comm, barrier); - CLOSE(comm, bcast); - CLOSE(comm, exscan); - CLOSE(comm, gather); - CLOSE(comm, gatherv); - CLOSE(comm, reduce); - CLOSE(comm, reduce_scatter_block); - CLOSE(comm, reduce_scatter); - CLOSE(comm, scan); - CLOSE(comm, scatter); - CLOSE(comm, scatterv); - - CLOSE(comm, iallgather); - CLOSE(comm, iallgatherv); - CLOSE(comm, iallreduce); - CLOSE(comm, ialltoall); - CLOSE(comm, ialltoallv); - CLOSE(comm, ialltoallw); - CLOSE(comm, ibarrier); - CLOSE(comm, ibcast); - CLOSE(comm, iexscan); - CLOSE(comm, igather); - CLOSE(comm, igatherv); - CLOSE(comm, ireduce); - CLOSE(comm, ireduce_scatter_block); - CLOSE(comm, ireduce_scatter); - CLOSE(comm, iscan); - CLOSE(comm, iscatter); - CLOSE(comm, iscatterv); - - CLOSE(comm, allgather_init); - CLOSE(comm, allgatherv_init); - CLOSE(comm, allreduce_init); - CLOSE(comm, alltoall_init); - CLOSE(comm, alltoallv_init); - CLOSE(comm, alltoallw_init); - CLOSE(comm, barrier_init); - CLOSE(comm, bcast_init); - CLOSE(comm, exscan_init); - CLOSE(comm, gather_init); - CLOSE(comm, gatherv_init); - CLOSE(comm, reduce_init); - CLOSE(comm, reduce_scatter_block_init); - CLOSE(comm, reduce_scatter_init); - CLOSE(comm, scan_init); - CLOSE(comm, scatter_init); - CLOSE(comm, scatterv_init); - - CLOSE(comm, neighbor_allgather); - CLOSE(comm, neighbor_allgatherv); - CLOSE(comm, neighbor_alltoall); - CLOSE(comm, neighbor_alltoallv); - CLOSE(comm, neighbor_alltoallw); - - CLOSE(comm, ineighbor_allgather); - CLOSE(comm, ineighbor_allgatherv); - CLOSE(comm, ineighbor_alltoall); - CLOSE(comm, ineighbor_alltoallv); - CLOSE(comm, ineighbor_alltoallw); - - CLOSE(comm, neighbor_allgather_init); - CLOSE(comm, neighbor_allgatherv_init); - CLOSE(comm, neighbor_alltoall_init); - CLOSE(comm, neighbor_alltoallv_init); - CLOSE(comm, neighbor_alltoallw_init); - - CLOSE(comm, reduce_local); - -#if OPAL_ENABLE_FT_MPI - CLOSE(comm, agree); - 
CLOSE(comm, iagree); -#endif - - for (item = opal_list_remove_first(comm->c_coll->module_list); - NULL != item; item = opal_list_remove_first(comm->c_coll->module_list)) { + /* Call module disable in the reverse order in which enable has been called + * in order to allow the modules to properly chain themselves. + */ + for (item = opal_list_remove_last(comm->c_coll->module_list); + NULL != item; item = opal_list_remove_last(comm->c_coll->module_list)) { mca_coll_base_avail_coll_t *avail = (mca_coll_base_avail_coll_t *) item; if(avail->ac_module) { + if (NULL != avail->ac_module->coll_module_disable ) { + avail->ac_module->coll_module_disable(avail->ac_module, comm); + } OBJ_RELEASE(avail->ac_module); } OBJ_RELEASE(avail); } OBJ_RELEASE(comm->c_coll->module_list); + CHECK_CLEAN_COLL(comm, allgather); + CHECK_CLEAN_COLL(comm, allgatherv); + CHECK_CLEAN_COLL(comm, allreduce); + CHECK_CLEAN_COLL(comm, alltoall); + CHECK_CLEAN_COLL(comm, alltoallv); + CHECK_CLEAN_COLL(comm, alltoallw); + CHECK_CLEAN_COLL(comm, barrier); + CHECK_CLEAN_COLL(comm, bcast); + CHECK_CLEAN_COLL(comm, exscan); + CHECK_CLEAN_COLL(comm, gather); + CHECK_CLEAN_COLL(comm, gatherv); + CHECK_CLEAN_COLL(comm, reduce); + CHECK_CLEAN_COLL(comm, reduce_scatter_block); + CHECK_CLEAN_COLL(comm, reduce_scatter); + CHECK_CLEAN_COLL(comm, scan); + CHECK_CLEAN_COLL(comm, scatter); + CHECK_CLEAN_COLL(comm, scatterv); + + CHECK_CLEAN_COLL(comm, iallgather); + CHECK_CLEAN_COLL(comm, iallgatherv); + CHECK_CLEAN_COLL(comm, iallreduce); + CHECK_CLEAN_COLL(comm, ialltoall); + CHECK_CLEAN_COLL(comm, ialltoallv); + CHECK_CLEAN_COLL(comm, ialltoallw); + CHECK_CLEAN_COLL(comm, ibarrier); + CHECK_CLEAN_COLL(comm, ibcast); + CHECK_CLEAN_COLL(comm, iexscan); + CHECK_CLEAN_COLL(comm, igather); + CHECK_CLEAN_COLL(comm, igatherv); + CHECK_CLEAN_COLL(comm, ireduce); + CHECK_CLEAN_COLL(comm, ireduce_scatter_block); + CHECK_CLEAN_COLL(comm, ireduce_scatter); + CHECK_CLEAN_COLL(comm, iscan); + CHECK_CLEAN_COLL(comm, iscatter); + 
CHECK_CLEAN_COLL(comm, iscatterv); + + CHECK_CLEAN_COLL(comm, allgather_init); + CHECK_CLEAN_COLL(comm, allgatherv_init); + CHECK_CLEAN_COLL(comm, allreduce_init); + CHECK_CLEAN_COLL(comm, alltoall_init); + CHECK_CLEAN_COLL(comm, alltoallv_init); + CHECK_CLEAN_COLL(comm, alltoallw_init); + CHECK_CLEAN_COLL(comm, barrier_init); + CHECK_CLEAN_COLL(comm, bcast_init); + CHECK_CLEAN_COLL(comm, exscan_init); + CHECK_CLEAN_COLL(comm, gather_init); + CHECK_CLEAN_COLL(comm, gatherv_init); + CHECK_CLEAN_COLL(comm, reduce_init); + CHECK_CLEAN_COLL(comm, reduce_scatter_block_init); + CHECK_CLEAN_COLL(comm, reduce_scatter_init); + CHECK_CLEAN_COLL(comm, scan_init); + CHECK_CLEAN_COLL(comm, scatter_init); + CHECK_CLEAN_COLL(comm, scatterv_init); + + CHECK_CLEAN_COLL(comm, neighbor_allgather); + CHECK_CLEAN_COLL(comm, neighbor_allgatherv); + CHECK_CLEAN_COLL(comm, neighbor_alltoall); + CHECK_CLEAN_COLL(comm, neighbor_alltoallv); + CHECK_CLEAN_COLL(comm, neighbor_alltoallw); + + CHECK_CLEAN_COLL(comm, ineighbor_allgather); + CHECK_CLEAN_COLL(comm, ineighbor_allgatherv); + CHECK_CLEAN_COLL(comm, ineighbor_alltoall); + CHECK_CLEAN_COLL(comm, ineighbor_alltoallv); + CHECK_CLEAN_COLL(comm, ineighbor_alltoallw); + + CHECK_CLEAN_COLL(comm, neighbor_allgather_init); + CHECK_CLEAN_COLL(comm, neighbor_allgatherv_init); + CHECK_CLEAN_COLL(comm, neighbor_alltoall_init); + CHECK_CLEAN_COLL(comm, neighbor_alltoallv_init); + CHECK_CLEAN_COLL(comm, neighbor_alltoallw_init); + + CHECK_CLEAN_COLL(comm, reduce_local); + +#if OPAL_ENABLE_FT_MPI + CHECK_CLEAN_COLL(comm, agree); + CHECK_CLEAN_COLL(comm, iagree); +#endif + free(comm->c_coll); comm->c_coll = NULL; diff --git a/ompi/mca/coll/base/help-mca-coll-base.txt b/ompi/mca/coll/base/help-mca-coll-base.txt index d6e0071fa7a..f5afb0ce19b 100644 --- a/ompi/mca/coll/base/help-mca-coll-base.txt +++ b/ompi/mca/coll/base/help-mca-coll-base.txt @@ -10,7 +10,8 @@ # University of Stuttgart. All rights reserved. 
# Copyright (c) 2004-2005 The Regents of the University of California. # All rights reserved. -# Copyright (c) 2015 Cisco Systems, Inc. All rights reserved. +# Copyright (c) 2015 Cisco Systems, Inc. All rights reserved. +# Copyright (c) 2024 NVIDIA Corporation. All rights reserved. # $COPYRIGHT$ # # Additional copyrights may follow @@ -45,3 +46,10 @@ using it was destroyed. This is somewhat unusual: the module itself may be at fault, or this may be a symptom of another issue (e.g., a memory problem). +# +[comm-select:missing collective] +%s %s collective requires support from a prior collective module for the fallback +case. Such support is missing in this run. +hostname: %s +priority: %d +# diff --git a/ompi/mca/coll/basic/coll_basic.h b/ompi/mca/coll/basic/coll_basic.h index 3cb36692e47..8e38655be00 100644 --- a/ompi/mca/coll/basic/coll_basic.h +++ b/ompi/mca/coll/basic/coll_basic.h @@ -53,9 +53,6 @@ BEGIN_C_DECLS *mca_coll_basic_comm_query(struct ompi_communicator_t *comm, int *priority); - int mca_coll_basic_module_enable(mca_coll_base_module_t *module, - struct ompi_communicator_t *comm); - int mca_coll_basic_allgather_inter(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void *rbuf, int rcount, diff --git a/ompi/mca/coll/basic/coll_basic_module.c b/ompi/mca/coll/basic/coll_basic_module.c index 6391b735b5d..9913a3cda26 100644 --- a/ompi/mca/coll/basic/coll_basic_module.c +++ b/ompi/mca/coll/basic/coll_basic_module.c @@ -36,6 +36,13 @@ #include "coll_basic.h" +static int +mca_coll_basic_module_enable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm); +static int +mca_coll_basic_module_disable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm); + /* * Initial query function that is invoked during MPI_INIT, allowing * this component to disqualify itself if it doesn't support the @@ -50,7 +57,6 @@ mca_coll_basic_init_query(bool enable_progress_threads, return OMPI_SUCCESS; } - /* * Invoked when there's a new 
communicator that has been created. * Look at the communicator and decide which set of functions and @@ -70,82 +76,19 @@ mca_coll_basic_comm_query(struct ompi_communicator_t *comm, /* Choose whether to use [intra|inter], and [linear|log]-based * algorithms. */ basic_module->super.coll_module_enable = mca_coll_basic_module_enable; - - if (OMPI_COMM_IS_INTER(comm)) { - basic_module->super.coll_allgather = mca_coll_basic_allgather_inter; - basic_module->super.coll_allgatherv = mca_coll_basic_allgatherv_inter; - basic_module->super.coll_allreduce = mca_coll_basic_allreduce_inter; - basic_module->super.coll_alltoall = mca_coll_basic_alltoall_inter; - basic_module->super.coll_alltoallv = mca_coll_basic_alltoallv_inter; - basic_module->super.coll_alltoallw = mca_coll_basic_alltoallw_inter; - basic_module->super.coll_barrier = mca_coll_basic_barrier_inter_lin; - basic_module->super.coll_bcast = mca_coll_basic_bcast_lin_inter; - basic_module->super.coll_exscan = NULL; - basic_module->super.coll_gather = mca_coll_basic_gather_inter; - basic_module->super.coll_gatherv = mca_coll_basic_gatherv_inter; - basic_module->super.coll_reduce = mca_coll_basic_reduce_lin_inter; - basic_module->super.coll_reduce_scatter_block = mca_coll_basic_reduce_scatter_block_inter; - basic_module->super.coll_reduce_scatter = mca_coll_basic_reduce_scatter_inter; - basic_module->super.coll_scan = NULL; - basic_module->super.coll_scatter = mca_coll_basic_scatter_inter; - basic_module->super.coll_scatterv = mca_coll_basic_scatterv_inter; - } else if (ompi_comm_size(comm) <= mca_coll_basic_crossover) { - basic_module->super.coll_allgather = ompi_coll_base_allgather_intra_basic_linear; - basic_module->super.coll_allgatherv = ompi_coll_base_allgatherv_intra_basic_default; - basic_module->super.coll_allreduce = mca_coll_basic_allreduce_intra; - basic_module->super.coll_alltoall = ompi_coll_base_alltoall_intra_basic_linear; - basic_module->super.coll_alltoallv = ompi_coll_base_alltoallv_intra_basic_linear; - 
basic_module->super.coll_alltoallw = mca_coll_basic_alltoallw_intra; - basic_module->super.coll_barrier = ompi_coll_base_barrier_intra_basic_linear; - basic_module->super.coll_bcast = ompi_coll_base_bcast_intra_basic_linear; - basic_module->super.coll_exscan = ompi_coll_base_exscan_intra_linear; - basic_module->super.coll_gather = ompi_coll_base_gather_intra_basic_linear; - basic_module->super.coll_gatherv = mca_coll_basic_gatherv_intra; - basic_module->super.coll_reduce = ompi_coll_base_reduce_intra_basic_linear; - basic_module->super.coll_reduce_scatter_block = mca_coll_basic_reduce_scatter_block_intra; - basic_module->super.coll_reduce_scatter = mca_coll_basic_reduce_scatter_intra; - basic_module->super.coll_scan = ompi_coll_base_scan_intra_linear; - basic_module->super.coll_scatter = ompi_coll_base_scatter_intra_basic_linear; - basic_module->super.coll_scatterv = mca_coll_basic_scatterv_intra; - } else { - basic_module->super.coll_allgather = ompi_coll_base_allgather_intra_basic_linear; - basic_module->super.coll_allgatherv = ompi_coll_base_allgatherv_intra_basic_default; - basic_module->super.coll_allreduce = mca_coll_basic_allreduce_intra; - basic_module->super.coll_alltoall = ompi_coll_base_alltoall_intra_basic_linear; - basic_module->super.coll_alltoallv = ompi_coll_base_alltoallv_intra_basic_linear; - basic_module->super.coll_alltoallw = mca_coll_basic_alltoallw_intra; - basic_module->super.coll_barrier = mca_coll_basic_barrier_intra_log; - basic_module->super.coll_bcast = mca_coll_basic_bcast_log_intra; - basic_module->super.coll_exscan = ompi_coll_base_exscan_intra_linear; - basic_module->super.coll_gather = ompi_coll_base_gather_intra_basic_linear; - basic_module->super.coll_gatherv = mca_coll_basic_gatherv_intra; - basic_module->super.coll_reduce = mca_coll_basic_reduce_log_intra; - basic_module->super.coll_reduce_scatter_block = mca_coll_basic_reduce_scatter_block_intra; - basic_module->super.coll_reduce_scatter = mca_coll_basic_reduce_scatter_intra; 
- basic_module->super.coll_scan = ompi_coll_base_scan_intra_linear; - basic_module->super.coll_scatter = ompi_coll_base_scatter_intra_basic_linear; - basic_module->super.coll_scatterv = mca_coll_basic_scatterv_intra; - } - - /* These functions will return an error code if comm does not have a virtual topology */ - basic_module->super.coll_neighbor_allgather = mca_coll_basic_neighbor_allgather; - basic_module->super.coll_neighbor_allgatherv = mca_coll_basic_neighbor_allgatherv; - basic_module->super.coll_neighbor_alltoall = mca_coll_basic_neighbor_alltoall; - basic_module->super.coll_neighbor_alltoallv = mca_coll_basic_neighbor_alltoallv; - basic_module->super.coll_neighbor_alltoallw = mca_coll_basic_neighbor_alltoallw; - - basic_module->super.coll_reduce_local = mca_coll_base_reduce_local; - -#if OPAL_ENABLE_FT_MPI - /* Default to some shim mappings over allreduce */ - basic_module->super.coll_agree = ompi_coll_base_agree_noft; - basic_module->super.coll_iagree = ompi_coll_base_iagree_noft; -#endif + basic_module->super.coll_module_disable = mca_coll_basic_module_disable; return &(basic_module->super); } +#define BASIC_INSTALL_COLL_API(__comm, __module, __api, __coll) \ + do \ + { \ + (__module)->super.coll_##__api = __coll; \ + MCA_COLL_INSTALL_API(__comm, __api, __coll, &__module->super, "basic"); \ + } while (0) + /* * Init module on the communicator */ @@ -153,12 +96,158 @@ int mca_coll_basic_module_enable(mca_coll_base_module_t *module, struct ompi_communicator_t *comm) { + mca_coll_basic_module_t *basic_module = (mca_coll_basic_module_t*)module; + /* prepare the placeholder for the array of request* */ module->base_data = OBJ_NEW(mca_coll_base_comm_t); if (NULL == module->base_data) { return OMPI_ERROR; } + if (OMPI_COMM_IS_INTER(comm)) + { + BASIC_INSTALL_COLL_API(comm, basic_module, allgather, mca_coll_basic_allgather_inter); + BASIC_INSTALL_COLL_API(comm, basic_module, allgatherv, mca_coll_basic_allgatherv_inter); + BASIC_INSTALL_COLL_API(comm, 
basic_module, allreduce, mca_coll_basic_allreduce_inter); + BASIC_INSTALL_COLL_API(comm, basic_module, alltoall, mca_coll_basic_alltoall_inter); + BASIC_INSTALL_COLL_API(comm, basic_module, alltoallv, mca_coll_basic_alltoallv_inter); + BASIC_INSTALL_COLL_API(comm, basic_module, alltoallw, mca_coll_basic_alltoallw_inter); + BASIC_INSTALL_COLL_API(comm, basic_module, barrier, mca_coll_basic_barrier_inter_lin); + BASIC_INSTALL_COLL_API(comm, basic_module, bcast, mca_coll_basic_bcast_lin_inter); + BASIC_INSTALL_COLL_API(comm, basic_module, gather, mca_coll_basic_gather_inter); + BASIC_INSTALL_COLL_API(comm, basic_module, gatherv, mca_coll_basic_gatherv_inter); + BASIC_INSTALL_COLL_API(comm, basic_module, reduce, mca_coll_basic_reduce_lin_inter); + BASIC_INSTALL_COLL_API(comm, basic_module, reduce_scatter_block, mca_coll_basic_reduce_scatter_block_inter); + BASIC_INSTALL_COLL_API(comm, basic_module, reduce_scatter, mca_coll_basic_reduce_scatter_inter); + BASIC_INSTALL_COLL_API(comm, basic_module, scatter, mca_coll_basic_scatter_inter); + BASIC_INSTALL_COLL_API(comm, basic_module, scatterv, mca_coll_basic_scatterv_inter); + } + else { + if (ompi_comm_size(comm) <= mca_coll_basic_crossover) { + BASIC_INSTALL_COLL_API(comm, basic_module, barrier, ompi_coll_base_barrier_intra_basic_linear); + BASIC_INSTALL_COLL_API(comm, basic_module, bcast, ompi_coll_base_bcast_intra_basic_linear); + BASIC_INSTALL_COLL_API(comm, basic_module, reduce, ompi_coll_base_reduce_intra_basic_linear); + } else { + BASIC_INSTALL_COLL_API(comm, basic_module, barrier, mca_coll_basic_barrier_intra_log); + BASIC_INSTALL_COLL_API(comm, basic_module, bcast, mca_coll_basic_bcast_log_intra); + BASIC_INSTALL_COLL_API(comm, basic_module, reduce, mca_coll_basic_reduce_log_intra); + } + BASIC_INSTALL_COLL_API(comm, basic_module, allgather, ompi_coll_base_allgather_intra_basic_linear); + BASIC_INSTALL_COLL_API(comm, basic_module, allgatherv, ompi_coll_base_allgatherv_intra_basic_default); + 
BASIC_INSTALL_COLL_API(comm, basic_module, allreduce, mca_coll_basic_allreduce_intra); + BASIC_INSTALL_COLL_API(comm, basic_module, alltoall, ompi_coll_base_alltoall_intra_basic_linear); + BASIC_INSTALL_COLL_API(comm, basic_module, alltoallv, ompi_coll_base_alltoallv_intra_basic_linear); + BASIC_INSTALL_COLL_API(comm, basic_module, alltoallw, mca_coll_basic_alltoallw_intra); + BASIC_INSTALL_COLL_API(comm, basic_module, exscan, ompi_coll_base_exscan_intra_linear); + BASIC_INSTALL_COLL_API(comm, basic_module, gather, ompi_coll_base_gather_intra_basic_linear); + BASIC_INSTALL_COLL_API(comm, basic_module, gatherv, mca_coll_basic_gatherv_intra); + BASIC_INSTALL_COLL_API(comm, basic_module, reduce_scatter_block, mca_coll_basic_reduce_scatter_block_intra); + BASIC_INSTALL_COLL_API(comm, basic_module, reduce_scatter, mca_coll_basic_reduce_scatter_intra); + BASIC_INSTALL_COLL_API(comm, basic_module, scan, ompi_coll_base_scan_intra_linear); + BASIC_INSTALL_COLL_API(comm, basic_module, scatter, ompi_coll_base_scatter_intra_basic_linear); + BASIC_INSTALL_COLL_API(comm, basic_module, scatterv, mca_coll_basic_scatterv_intra); + } + /* These functions will return an error code if comm does not have a virtual topology */ + BASIC_INSTALL_COLL_API(comm, basic_module, neighbor_allgather, mca_coll_basic_neighbor_allgather); + BASIC_INSTALL_COLL_API(comm, basic_module, neighbor_allgatherv, mca_coll_basic_neighbor_allgatherv); + BASIC_INSTALL_COLL_API(comm, basic_module, neighbor_alltoall, mca_coll_basic_neighbor_alltoall); + BASIC_INSTALL_COLL_API(comm, basic_module, neighbor_alltoallv, mca_coll_basic_neighbor_alltoallv); + BASIC_INSTALL_COLL_API(comm, basic_module, neighbor_alltoallw, mca_coll_basic_neighbor_alltoallw); + + /* Default to some shim mappings over allreduce */ + BASIC_INSTALL_COLL_API(comm, basic_module, agree, ompi_coll_base_agree_noft); + BASIC_INSTALL_COLL_API(comm, basic_module, iagree, ompi_coll_base_iagree_noft); + + BASIC_INSTALL_COLL_API(comm, basic_module, 
reduce_local, mca_coll_base_reduce_local); + + /* All done */ + return OMPI_SUCCESS; +} + +#define BASIC_UNINSTALL_COLL_API(__comm, __module, __api) \ + do \ + { \ + if (__comm->c_coll->coll_##__api##_module == &__module->super ) { \ + MCA_COLL_INSTALL_API(__comm, __api, NULL, NULL, "basic"); \ + } \ + } while (0) + +int +mca_coll_basic_module_disable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm) +{ + mca_coll_basic_module_t *basic_module = (mca_coll_basic_module_t*)module; + + if (OMPI_COMM_IS_INTER(comm)) + { + BASIC_UNINSTALL_COLL_API(comm, basic_module, allgather); + BASIC_UNINSTALL_COLL_API(comm, basic_module, allgatherv); + BASIC_UNINSTALL_COLL_API(comm, basic_module, allreduce); + BASIC_UNINSTALL_COLL_API(comm, basic_module, alltoall); + BASIC_UNINSTALL_COLL_API(comm, basic_module, alltoallv); + BASIC_UNINSTALL_COLL_API(comm, basic_module, alltoallw); + BASIC_UNINSTALL_COLL_API(comm, basic_module, barrier); + BASIC_UNINSTALL_COLL_API(comm, basic_module, bcast); + BASIC_UNINSTALL_COLL_API(comm, basic_module, gather); + BASIC_UNINSTALL_COLL_API(comm, basic_module, gatherv); + BASIC_UNINSTALL_COLL_API(comm, basic_module, reduce); + BASIC_UNINSTALL_COLL_API(comm, basic_module, reduce_scatter_block); + BASIC_UNINSTALL_COLL_API(comm, basic_module, reduce_scatter); + BASIC_UNINSTALL_COLL_API(comm, basic_module, scatter); + BASIC_UNINSTALL_COLL_API(comm, basic_module, scatterv); + } + else if (ompi_comm_size(comm) <= mca_coll_basic_crossover) + { + BASIC_UNINSTALL_COLL_API(comm, basic_module, allgather); + BASIC_UNINSTALL_COLL_API(comm, basic_module, allgatherv); + BASIC_UNINSTALL_COLL_API(comm, basic_module, allreduce); + BASIC_UNINSTALL_COLL_API(comm, basic_module, alltoall); + BASIC_UNINSTALL_COLL_API(comm, basic_module, alltoallv); + BASIC_UNINSTALL_COLL_API(comm, basic_module, alltoallw); + BASIC_UNINSTALL_COLL_API(comm, basic_module, barrier); + BASIC_UNINSTALL_COLL_API(comm, basic_module, bcast); + BASIC_UNINSTALL_COLL_API(comm, 
basic_module, exscan); + BASIC_UNINSTALL_COLL_API(comm, basic_module, gather); + BASIC_UNINSTALL_COLL_API(comm, basic_module, gatherv); + BASIC_UNINSTALL_COLL_API(comm, basic_module, reduce); + BASIC_UNINSTALL_COLL_API(comm, basic_module, reduce_scatter_block); + BASIC_UNINSTALL_COLL_API(comm, basic_module, reduce_scatter); + BASIC_UNINSTALL_COLL_API(comm, basic_module, scan); + BASIC_UNINSTALL_COLL_API(comm, basic_module, scatter); + BASIC_UNINSTALL_COLL_API(comm, basic_module, scatterv); + } + else + { + BASIC_UNINSTALL_COLL_API(comm, basic_module, allgather); + BASIC_UNINSTALL_COLL_API(comm, basic_module, allgatherv); + BASIC_UNINSTALL_COLL_API(comm, basic_module, allreduce); + BASIC_UNINSTALL_COLL_API(comm, basic_module, alltoall); + BASIC_UNINSTALL_COLL_API(comm, basic_module, alltoallv); + BASIC_UNINSTALL_COLL_API(comm, basic_module, alltoallw); + BASIC_UNINSTALL_COLL_API(comm, basic_module, barrier); + BASIC_UNINSTALL_COLL_API(comm, basic_module, bcast); + BASIC_UNINSTALL_COLL_API(comm, basic_module, exscan); + BASIC_UNINSTALL_COLL_API(comm, basic_module, gather); + BASIC_UNINSTALL_COLL_API(comm, basic_module, gatherv); + BASIC_UNINSTALL_COLL_API(comm, basic_module, reduce); + BASIC_UNINSTALL_COLL_API(comm, basic_module, reduce_scatter_block); + BASIC_UNINSTALL_COLL_API(comm, basic_module, reduce_scatter); + BASIC_UNINSTALL_COLL_API(comm, basic_module, scan); + BASIC_UNINSTALL_COLL_API(comm, basic_module, scatter); + BASIC_UNINSTALL_COLL_API(comm, basic_module, scatterv); + } + + /* These functions will return an error code if comm does not have a virtual topology */ + BASIC_UNINSTALL_COLL_API(comm, basic_module, neighbor_allgather); + BASIC_UNINSTALL_COLL_API(comm, basic_module, neighbor_allgatherv); + BASIC_UNINSTALL_COLL_API(comm, basic_module, neighbor_alltoall); + BASIC_UNINSTALL_COLL_API(comm, basic_module, neighbor_alltoallv); + BASIC_UNINSTALL_COLL_API(comm, basic_module, neighbor_alltoallw); + + BASIC_UNINSTALL_COLL_API(comm, basic_module, agree); + 
BASIC_UNINSTALL_COLL_API(comm, basic_module, iagree); + + BASIC_UNINSTALL_COLL_API(comm, basic_module, reduce_local); /* All done */ return OMPI_SUCCESS; } diff --git a/ompi/mca/coll/coll.h b/ompi/mca/coll/coll.h index 4908ece2d60..dd9d8a42a38 100644 --- a/ompi/mca/coll/coll.h +++ b/ompi/mca/coll/coll.h @@ -20,6 +20,7 @@ * Copyright (c) 2016-2017 IBM Corporation. All rights reserved. * Copyright (c) 2017 FUJITSU LIMITED. All rights reserved. * Copyright (c) 2020 BULL S.A.S. All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -80,6 +81,7 @@ #include "opal/mca/base/base.h" #include "ompi/request/request.h" +#include "ompi/mca/coll/base/base.h" BEGIN_C_DECLS @@ -817,6 +819,29 @@ typedef struct mca_coll_base_comm_coll_t mca_coll_base_comm_coll_t; /* ******************************************************************** */ +#define MCA_COLL_SAVE_API(__comm, __api, __fct, __bmodule, __c_name) \ + do \ + { \ + OPAL_OUTPUT_VERBOSE((50, ompi_coll_base_framework.framework_output, \ + "Save %s collective in comm %p (%s) %p into %s:%s", \ + #__api, (void *)__comm, __comm->c_name, \ + (void*)(uintptr_t)__comm->c_coll->coll_##__api, \ + __c_name, #__fct)); \ + __fct = __comm->c_coll->coll_##__api; \ + __bmodule = __comm->c_coll->coll_##__api##_module; \ + } while (0) + +#define MCA_COLL_INSTALL_API(__comm, __api, __fct, __bmodule, __c_name) \ + do \ + { \ + OPAL_OUTPUT_VERBOSE((50, ompi_coll_base_framework.framework_output, \ + "Replace %s collective in comm %p (%s) from %p to %s:%s(%p)", \ + #__api, (void *)__comm, __comm->c_name, \ + (void*)(uintptr_t)__comm->c_coll->coll_##__api, \ + __c_name, #__fct, (void*)(uintptr_t)__fct)); \ + __comm->c_coll->coll_##__api = __fct; \ + __comm->c_coll->coll_##__api##_module = __bmodule; \ + } while (0) END_C_DECLS diff --git a/ompi/mca/coll/demo/coll_demo.h b/ompi/mca/coll/demo/coll_demo.h index 20ebe728b7a..d9ee2b633e2 100644 --- 
a/ompi/mca/coll/demo/coll_demo.h +++ b/ompi/mca/coll/demo/coll_demo.h @@ -10,6 +10,7 @@ * Copyright (c) 2004-2005 The Regents of the University of California. * All rights reserved. * Copyright (c) 2008 Cisco Systems, Inc. All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -30,7 +31,7 @@ BEGIN_C_DECLS /* Globally exported variables */ -OMPI_DECLSPEC extern const mca_coll_base_component_2_4_0_t mca_coll_demo_component; + OMPI_DECLSPEC extern const mca_coll_base_component_2_4_0_t mca_coll_demo_component; extern int mca_coll_demo_priority; extern int mca_coll_demo_verbose; @@ -39,87 +40,84 @@ OMPI_DECLSPEC extern const mca_coll_base_component_2_4_0_t mca_coll_demo_compone int mca_coll_demo_init_query(bool enable_progress_threads, bool enable_mpi_threads); -mca_coll_base_module_t * -mca_coll_demo_comm_query(struct ompi_communicator_t *comm, int *priority); + mca_coll_base_module_t * + mca_coll_demo_comm_query(struct ompi_communicator_t *comm, int *priority); /* Module functions */ -int mca_coll_demo_module_enable(mca_coll_base_module_t *module, - struct ompi_communicator_t *comm); - - int mca_coll_demo_allgather_intra(void *sbuf, int scount, + int mca_coll_demo_allgather_intra(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void *rbuf, int rcount, struct ompi_datatype_t *rdtype, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_allgather_inter(void *sbuf, int scount, + int mca_coll_demo_allgather_inter(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void *rbuf, int rcount, struct ompi_datatype_t *rdtype, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_allgatherv_intra(void *sbuf, int scount, + int mca_coll_demo_allgatherv_intra(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, - void * rbuf, int *rcounts, int *disps, + void * rbuf, const int *rcounts, const int *disps, 
struct ompi_datatype_t *rdtype, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_allgatherv_inter(void *sbuf, int scount, + int mca_coll_demo_allgatherv_inter(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, - void * rbuf, int *rcounts, int *disps, + void * rbuf, const int *rcounts, const int *disps, struct ompi_datatype_t *rdtype, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_allreduce_intra(void *sbuf, void *rbuf, int count, + int mca_coll_demo_allreduce_intra(const void *sbuf, void *rbuf, int count, struct ompi_datatype_t *dtype, struct ompi_op_t *op, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_allreduce_inter(void *sbuf, void *rbuf, int count, + int mca_coll_demo_allreduce_inter(const void *sbuf, void *rbuf, int count, struct ompi_datatype_t *dtype, struct ompi_op_t *op, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_alltoall_intra(void *sbuf, int scount, + int mca_coll_demo_alltoall_intra(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void* rbuf, int rcount, struct ompi_datatype_t *rdtype, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_alltoall_inter(void *sbuf, int scount, + int mca_coll_demo_alltoall_inter(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void* rbuf, int rcount, struct ompi_datatype_t *rdtype, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_alltoallv_intra(void *sbuf, int *scounts, int *sdisps, + int mca_coll_demo_alltoallv_intra(const void *sbuf, const int *scounts, const int *sdisps, struct ompi_datatype_t *sdtype, - void *rbuf, int *rcounts, int *rdisps, + void *rbuf, const int *rcounts, const int *rdisps, struct ompi_datatype_t *rdtype, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_alltoallv_inter(void *sbuf, int 
*scounts, int *sdisps, + int mca_coll_demo_alltoallv_inter(const void *sbuf, const int *scounts, const int *sdisps, struct ompi_datatype_t *sdtype, - void *rbuf, int *rcounts, int *rdisps, + void *rbuf, const int *rcounts, const int *rdisps, struct ompi_datatype_t *rdtype, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_alltoallw_intra(void *sbuf, int *scounts, int *sdisps, - struct ompi_datatype_t **sdtypes, - void *rbuf, int *rcounts, int *rdisps, - struct ompi_datatype_t **rdtypes, + int mca_coll_demo_alltoallw_intra(const void *sbuf, const int *scounts, const int *sdisps, + struct ompi_datatype_t *const *sdtypes, + void *rbuf, const int *rcounts, const int *rdisps, + struct ompi_datatype_t *const *rdtypes, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_alltoallw_inter(void *sbuf, int *scounts, int *sdisps, - struct ompi_datatype_t **sdtypes, - void *rbuf, int *rcounts, int *rdisps, - struct ompi_datatype_t **rdtypes, + int mca_coll_demo_alltoallw_inter(const void *sbuf, const int *scounts, const int *sdisps, + struct ompi_datatype_t *const *sdtypes, + void *rbuf, const int *rcounts, const int *rdisps, + struct ompi_datatype_t *const *rdtypes, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); @@ -139,109 +137,109 @@ int mca_coll_demo_module_enable(mca_coll_base_module_t *module, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_exscan_intra(void *sbuf, void *rbuf, int count, + int mca_coll_demo_exscan_intra(const void *sbuf, void *rbuf, int count, struct ompi_datatype_t *dtype, struct ompi_op_t *op, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_exscan_inter(void *sbuf, void *rbuf, int count, + int mca_coll_demo_exscan_inter(const void *sbuf, void *rbuf, int count, struct ompi_datatype_t *dtype, struct ompi_op_t *op, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int 
mca_coll_demo_gather_intra(void *sbuf, int scount, + int mca_coll_demo_gather_intra(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void *rbuf, int rcount, struct ompi_datatype_t *rdtype, int root, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_gather_inter(void *sbuf, int scount, + int mca_coll_demo_gather_inter(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void *rbuf, int rcount, struct ompi_datatype_t *rdtype, int root, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_gatherv_intra(void *sbuf, int scount, + int mca_coll_demo_gatherv_intra(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void *rbuf, - int *rcounts, int *disps, + const int *rcounts, const int *disps, struct ompi_datatype_t *rdtype, int root, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_gatherv_inter(void *sbuf, int scount, + int mca_coll_demo_gatherv_inter(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void *rbuf, - int *rcounts, int *disps, + const int *rcounts, const int *disps, struct ompi_datatype_t *rdtype, int root, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_reduce_intra(void *sbuf, void* rbuf, int count, + int mca_coll_demo_reduce_intra(const void *sbuf, void* rbuf, int count, struct ompi_datatype_t *dtype, struct ompi_op_t *op, int root, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_reduce_inter(void *sbuf, void* rbuf, int count, + int mca_coll_demo_reduce_inter(const void *sbuf, void* rbuf, int count, struct ompi_datatype_t *dtype, struct ompi_op_t *op, int root, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_reduce_scatter_intra(void *sbuf, void *rbuf, - int *rcounts, + int mca_coll_demo_reduce_scatter_intra(const void *sbuf, void *rbuf, + const int *rcounts, struct ompi_datatype_t *dtype, 
struct ompi_op_t *op, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_reduce_scatter_inter(void *sbuf, void *rbuf, - int *rcounts, + int mca_coll_demo_reduce_scatter_inter(const void *sbuf, void *rbuf, + const int *rcounts, struct ompi_datatype_t *dtype, struct ompi_op_t *op, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_scan_intra(void *sbuf, void *rbuf, int count, + int mca_coll_demo_scan_intra(const void *sbuf, void *rbuf, int count, struct ompi_datatype_t *dtype, struct ompi_op_t *op, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_scan_inter(void *sbuf, void *rbuf, int count, + int mca_coll_demo_scan_inter(const void *sbuf, void *rbuf, int count, struct ompi_datatype_t *dtype, struct ompi_op_t *op, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_scatter_intra(void *sbuf, int scount, + int mca_coll_demo_scatter_intra(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void *rbuf, int rcount, struct ompi_datatype_t *rdtype, int root, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_scatter_inter(void *sbuf, int scount, + int mca_coll_demo_scatter_inter(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void *rbuf, int rcount, struct ompi_datatype_t *rdtype, int root, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_scatterv_intra(void *sbuf, int *scounts, int *disps, + int mca_coll_demo_scatterv_intra(const void *sbuf, const int *scounts, const int *disps, struct ompi_datatype_t *sdtype, void* rbuf, int rcount, struct ompi_datatype_t *rdtype, int root, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_scatterv_inter(void *sbuf, int *scounts, int *disps, + int mca_coll_demo_scatterv_inter(const void *sbuf, const int *scounts, const int *disps, struct ompi_datatype_t *sdtype, 
void* rbuf, int rcount, struct ompi_datatype_t *rdtype, int root, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); -struct mca_coll_demo_module_t { - mca_coll_base_module_t super; + struct mca_coll_demo_module_t { + mca_coll_base_module_t super; - mca_coll_base_comm_coll_t underlying; -}; -typedef struct mca_coll_demo_module_t mca_coll_demo_module_t; -OBJ_CLASS_DECLARATION(mca_coll_demo_module_t); + mca_coll_base_comm_coll_t c_coll; + }; + typedef struct mca_coll_demo_module_t mca_coll_demo_module_t; + OBJ_CLASS_DECLARATION(mca_coll_demo_module_t); END_C_DECLS diff --git a/ompi/mca/coll/demo/coll_demo_allgather.c b/ompi/mca/coll/demo/coll_demo_allgather.c index f33042a71d5..c2f2a820d45 100644 --- a/ompi/mca/coll/demo/coll_demo_allgather.c +++ b/ompi/mca/coll/demo/coll_demo_allgather.c @@ -9,6 +9,7 @@ * University of Stuttgart. All rights reserved. * Copyright (c) 2004-2005 The Regents of the University of California. * All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. 
* $COPYRIGHT$ * * Additional copyrights may follow @@ -33,7 +34,7 @@ * Accepts: - same as MPI_Allgather() * Returns: - MPI_SUCCESS or error code */ -int mca_coll_demo_allgather_intra(void *sbuf, int scount, +int mca_coll_demo_allgather_intra(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void *rbuf, int rcount, struct ompi_datatype_t *rdtype, struct ompi_communicator_t *comm, @@ -41,9 +42,9 @@ int mca_coll_demo_allgather_intra(void *sbuf, int scount, { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo allgather_intra"); - return demo_module->underlying.coll_allgather(sbuf, scount, sdtype, rbuf, - rcount, rdtype, comm, - demo_module->underlying.coll_allgather_module); + return demo_module->c_coll.coll_allgather(sbuf, scount, sdtype, rbuf, + rcount, rdtype, comm, + demo_module->c_coll.coll_allgather_module); } @@ -54,7 +55,7 @@ int mca_coll_demo_allgather_intra(void *sbuf, int scount, * Accepts: - same as MPI_Allgather() * Returns: - MPI_SUCCESS or error code */ -int mca_coll_demo_allgather_inter(void *sbuf, int scount, +int mca_coll_demo_allgather_inter(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void *rbuf, int rcount, struct ompi_datatype_t *rdtype, @@ -63,7 +64,7 @@ int mca_coll_demo_allgather_inter(void *sbuf, int scount, { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo allgather_inter"); - return demo_module->underlying.coll_allgather(sbuf, scount, sdtype, rbuf, - rcount, rdtype, comm, - demo_module->underlying.coll_allgather_module); + return demo_module->c_coll.coll_allgather(sbuf, scount, sdtype, rbuf, + rcount, rdtype, comm, + demo_module->c_coll.coll_allgather_module); } diff --git a/ompi/mca/coll/demo/coll_demo_allgatherv.c b/ompi/mca/coll/demo/coll_demo_allgatherv.c index b6503ec6865..1da300c3879 100644 --- 
a/ompi/mca/coll/demo/coll_demo_allgatherv.c +++ b/ompi/mca/coll/demo/coll_demo_allgatherv.c @@ -9,6 +9,7 @@ * University of Stuttgart. All rights reserved. * Copyright (c) 2004-2005 The Regents of the University of California. * All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -33,19 +34,19 @@ * Accepts: - same as MPI_Allgatherv() * Returns: - MPI_SUCCESS or error code */ -int mca_coll_demo_allgatherv_intra(void *sbuf, int scount, +int mca_coll_demo_allgatherv_intra(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, - void * rbuf, int *rcounts, int *disps, + void * rbuf, const int *rcounts, const int *disps, struct ompi_datatype_t *rdtype, struct ompi_communicator_t *comm, mca_coll_base_module_t *module) { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo allgatherv_intra"); - return demo_module->underlying.coll_allgatherv(sbuf, scount, sdtype, - rbuf, rcounts, disps, - rdtype, comm, - demo_module->underlying.coll_allgatherv_module); + return demo_module->c_coll.coll_allgatherv(sbuf, scount, sdtype, + rbuf, rcounts, disps, + rdtype, comm, + demo_module->c_coll.coll_allgatherv_module); } @@ -56,17 +57,17 @@ int mca_coll_demo_allgatherv_intra(void *sbuf, int scount, * Accepts: - same as MPI_Allgatherv() * Returns: - MPI_SUCCESS or error code */ -int mca_coll_demo_allgatherv_inter(void *sbuf, int scount, - struct ompi_datatype_t *sdtype, - void * rbuf, int *rcounts, int *disps, - struct ompi_datatype_t *rdtype, +int mca_coll_demo_allgatherv_inter(const void *sbuf, int scount, + struct ompi_datatype_t *sdtype, + void * rbuf, const int *rcounts, const int *disps, + struct ompi_datatype_t *rdtype, struct ompi_communicator_t *comm, mca_coll_base_module_t *module) { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, 
ompi_coll_base_framework.framework_output, "In demo allgatherv_inter"); - return demo_module->underlying.coll_allgatherv(sbuf, scount, sdtype, - rbuf, rcounts, disps, - rdtype, comm, - demo_module->underlying.coll_allgatherv_module); + return demo_module->c_coll.coll_allgatherv(sbuf, scount, sdtype, + rbuf, rcounts, disps, + rdtype, comm, + demo_module->c_coll.coll_allgatherv_module); } diff --git a/ompi/mca/coll/demo/coll_demo_allreduce.c b/ompi/mca/coll/demo/coll_demo_allreduce.c index 15975bacb1c..f8f833b1de1 100644 --- a/ompi/mca/coll/demo/coll_demo_allreduce.c +++ b/ompi/mca/coll/demo/coll_demo_allreduce.c @@ -9,6 +9,7 @@ * University of Stuttgart. All rights reserved. * Copyright (c) 2004-2005 The Regents of the University of California. * All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -33,7 +34,7 @@ * Accepts: - same as MPI_Allreduce() * Returns: - MPI_SUCCESS or error code */ -int mca_coll_demo_allreduce_intra(void *sbuf, void *rbuf, int count, +int mca_coll_demo_allreduce_intra(const void *sbuf, void *rbuf, int count, struct ompi_datatype_t *dtype, struct ompi_op_t *op, struct ompi_communicator_t *comm, @@ -41,9 +42,9 @@ int mca_coll_demo_allreduce_intra(void *sbuf, void *rbuf, int count, { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo allreduce_intra"); - return demo_module->underlying.coll_allreduce(sbuf, rbuf, count, dtype, - op, comm, - demo_module->underlying.coll_allreduce_module); + return demo_module->c_coll.coll_allreduce(sbuf, rbuf, count, dtype, + op, comm, + demo_module->c_coll.coll_allreduce_module); } @@ -54,7 +55,7 @@ int mca_coll_demo_allreduce_intra(void *sbuf, void *rbuf, int count, * Accepts: - same as MPI_Allreduce() * Returns: - MPI_SUCCESS or error code */ -int mca_coll_demo_allreduce_inter(void *sbuf, void *rbuf, int count, +int 
mca_coll_demo_allreduce_inter(const void *sbuf, void *rbuf, int count, struct ompi_datatype_t *dtype, struct ompi_op_t *op, struct ompi_communicator_t *comm, @@ -62,7 +63,7 @@ int mca_coll_demo_allreduce_inter(void *sbuf, void *rbuf, int count, { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo allreduce_inter"); - return demo_module->underlying.coll_allreduce(sbuf, rbuf, count, dtype, - op, comm, - demo_module->underlying.coll_allreduce_module); + return demo_module->c_coll.coll_allreduce(sbuf, rbuf, count, dtype, + op, comm, + demo_module->c_coll.coll_allreduce_module); } diff --git a/ompi/mca/coll/demo/coll_demo_alltoall.c b/ompi/mca/coll/demo/coll_demo_alltoall.c index d3559970121..41b5e7f4c3a 100644 --- a/ompi/mca/coll/demo/coll_demo_alltoall.c +++ b/ompi/mca/coll/demo/coll_demo_alltoall.c @@ -9,6 +9,7 @@ * University of Stuttgart. All rights reserved. * Copyright (c) 2004-2005 The Regents of the University of California. * All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. 
* $COPYRIGHT$ * * Additional copyrights may follow @@ -33,7 +34,7 @@ * Accepts: - same as MPI_Alltoall() * Returns: - MPI_SUCCESS or an MPI error code */ -int mca_coll_demo_alltoall_intra(void *sbuf, int scount, +int mca_coll_demo_alltoall_intra(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void *rbuf, int rcount, struct ompi_datatype_t *rdtype, @@ -42,10 +43,10 @@ int mca_coll_demo_alltoall_intra(void *sbuf, int scount, { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo alltoall_intra\n"); - return demo_module->underlying.coll_alltoall(sbuf, scount, sdtype, - rbuf, rcount, rdtype, - comm, - demo_module->underlying.coll_alltoall_module); + return demo_module->c_coll.coll_alltoall(sbuf, scount, sdtype, + rbuf, rcount, rdtype, + comm, + demo_module->c_coll.coll_alltoall_module); } @@ -56,7 +57,7 @@ int mca_coll_demo_alltoall_intra(void *sbuf, int scount, * Accepts: - same as MPI_Alltoall() * Returns: - MPI_SUCCESS or an MPI error code */ -int mca_coll_demo_alltoall_inter(void *sbuf, int scount, +int mca_coll_demo_alltoall_inter(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void *rbuf, int rcount, struct ompi_datatype_t *rdtype, @@ -65,8 +66,8 @@ int mca_coll_demo_alltoall_inter(void *sbuf, int scount, { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo alltoall_inter\n"); - return demo_module->underlying.coll_alltoall(sbuf, scount, sdtype, - rbuf, rcount, rdtype, - comm, - demo_module->underlying.coll_alltoall_module); + return demo_module->c_coll.coll_alltoall(sbuf, scount, sdtype, + rbuf, rcount, rdtype, + comm, + demo_module->c_coll.coll_alltoall_module); } diff --git a/ompi/mca/coll/demo/coll_demo_alltoallv.c b/ompi/mca/coll/demo/coll_demo_alltoallv.c index 0e8cf13861b..512a3654fd2 100644 --- a/ompi/mca/coll/demo/coll_demo_alltoallv.c +++ 
b/ompi/mca/coll/demo/coll_demo_alltoallv.c @@ -9,6 +9,7 @@ * University of Stuttgart. All rights reserved. * Copyright (c) 2004-2005 The Regents of the University of California. * All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -34,19 +35,19 @@ * Returns: - MPI_SUCCESS or an MPI error code */ int -mca_coll_demo_alltoallv_intra(void *sbuf, int *scounts, int *sdisps, +mca_coll_demo_alltoallv_intra(const void *sbuf, const int *scounts, const int *sdisps, struct ompi_datatype_t *sdtype, - void *rbuf, int *rcounts, int *rdisps, + void *rbuf, const int *rcounts, const int *rdisps, struct ompi_datatype_t *rdtype, struct ompi_communicator_t *comm, mca_coll_base_module_t *module) { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo alltoallv_intra"); - return demo_module->underlying.coll_alltoallv(sbuf, scounts, sdisps, - sdtype, rbuf, rcounts, - rdisps, rdtype, comm, - demo_module->underlying.coll_alltoallv_module); + return demo_module->c_coll.coll_alltoallv(sbuf, scounts, sdisps, + sdtype, rbuf, rcounts, + rdisps, rdtype, comm, + demo_module->c_coll.coll_alltoallv_module); } @@ -58,17 +59,17 @@ mca_coll_demo_alltoallv_intra(void *sbuf, int *scounts, int *sdisps, * Returns: - MPI_SUCCESS or an MPI error code */ int -mca_coll_demo_alltoallv_inter(void *sbuf, int *scounts, int *sdisps, +mca_coll_demo_alltoallv_inter(const void *sbuf, const int *scounts, const int *sdisps, struct ompi_datatype_t *sdtype, void *rbuf, - int *rcounts, int *rdisps, + const int *rcounts, const int *rdisps, struct ompi_datatype_t *rdtype, struct ompi_communicator_t *comm, mca_coll_base_module_t *module) { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo alltoallv_inter"); - return 
demo_module->underlying.coll_alltoallv(sbuf, scounts, sdisps, - sdtype, rbuf, rcounts, - rdisps, rdtype, comm, - demo_module->underlying.coll_alltoallv_module); + return demo_module->c_coll.coll_alltoallv(sbuf, scounts, sdisps, + sdtype, rbuf, rcounts, + rdisps, rdtype, comm, + demo_module->c_coll.coll_alltoallv_module); } diff --git a/ompi/mca/coll/demo/coll_demo_alltoallw.c b/ompi/mca/coll/demo/coll_demo_alltoallw.c index b9c29693178..de4bd8433ba 100644 --- a/ompi/mca/coll/demo/coll_demo_alltoallw.c +++ b/ompi/mca/coll/demo/coll_demo_alltoallw.c @@ -9,6 +9,7 @@ * University of Stuttgart. All rights reserved. * Copyright (c) 2004-2005 The Regents of the University of California. * All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -33,19 +34,19 @@ * Accepts: - same as MPI_Alltoallw() * Returns: - MPI_SUCCESS or an MPI error code */ -int mca_coll_demo_alltoallw_intra(void *sbuf, int *scounts, int *sdisps, - struct ompi_datatype_t **sdtypes, - void *rbuf, int *rcounts, int *rdisps, - struct ompi_datatype_t **rdtypes, +int mca_coll_demo_alltoallw_intra(const void *sbuf, const int *scounts, const int *sdisps, + struct ompi_datatype_t *const *sdtypes, + void *rbuf, const int *rcounts, const int *rdisps, + struct ompi_datatype_t *const *rdtypes, struct ompi_communicator_t *comm, mca_coll_base_module_t *module) { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo alltoallw_intra"); - return demo_module->underlying.coll_alltoallw(sbuf, scounts, sdisps, - sdtypes, rbuf, rcounts, - rdisps, rdtypes, comm, - demo_module->underlying.coll_alltoallw_module); + return demo_module->c_coll.coll_alltoallw(sbuf, scounts, sdisps, + sdtypes, rbuf, rcounts, + rdisps, rdtypes, comm, + demo_module->c_coll.coll_alltoallw_module); } @@ -56,17 +57,17 @@ int mca_coll_demo_alltoallw_intra(void *sbuf, int 
*scounts, int *sdisps, * Accepts: - same as MPI_Alltoallw() * Returns: - MPI_SUCCESS or an MPI error code */ -int mca_coll_demo_alltoallw_inter(void *sbuf, int *scounts, int *sdisps, - struct ompi_datatype_t **sdtypes, - void *rbuf, int *rcounts, int *rdisps, - struct ompi_datatype_t **rdtypes, +int mca_coll_demo_alltoallw_inter(const void *sbuf, const int *scounts, const int *sdisps, + struct ompi_datatype_t *const *sdtypes, + void *rbuf, const int *rcounts, const int *rdisps, + struct ompi_datatype_t *const *rdtypes, struct ompi_communicator_t *comm, mca_coll_base_module_t *module) { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo alltoallw_inter"); - return demo_module->underlying.coll_alltoallw(sbuf, scounts, sdisps, - sdtypes, rbuf, rcounts, - rdisps, rdtypes, comm, - demo_module->underlying.coll_alltoallw_module); + return demo_module->c_coll.coll_alltoallw(sbuf, scounts, sdisps, + sdtypes, rbuf, rcounts, + rdisps, rdtypes, comm, + demo_module->c_coll.coll_alltoallw_module); } diff --git a/ompi/mca/coll/demo/coll_demo_barrier.c b/ompi/mca/coll/demo/coll_demo_barrier.c index bcede2bf5b5..7965b8856ae 100644 --- a/ompi/mca/coll/demo/coll_demo_barrier.c +++ b/ompi/mca/coll/demo/coll_demo_barrier.c @@ -9,6 +9,7 @@ * University of Stuttgart. All rights reserved. * Copyright (c) 2004-2005 The Regents of the University of California. * All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. 
* $COPYRIGHT$ * * Additional copyrights may follow @@ -38,8 +39,8 @@ int mca_coll_demo_barrier_intra(struct ompi_communicator_t *comm, { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo barrier_intra"); - return demo_module->underlying.coll_barrier(comm, - demo_module->underlying.coll_barrier_module); + return demo_module->c_coll.coll_barrier(comm, + demo_module->c_coll.coll_barrier_module); } @@ -55,6 +56,6 @@ int mca_coll_demo_barrier_inter(struct ompi_communicator_t *comm, { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo barrier_inter"); - return demo_module->underlying.coll_barrier(comm, - demo_module->underlying.coll_barrier_module); + return demo_module->c_coll.coll_barrier(comm, + demo_module->c_coll.coll_barrier_module); } diff --git a/ompi/mca/coll/demo/coll_demo_bcast.c b/ompi/mca/coll/demo/coll_demo_bcast.c index 645c9e0dd62..3d42555bfc6 100644 --- a/ompi/mca/coll/demo/coll_demo_bcast.c +++ b/ompi/mca/coll/demo/coll_demo_bcast.c @@ -9,6 +9,7 @@ * University of Stuttgart. All rights reserved. * Copyright (c) 2004-2005 The Regents of the University of California. * All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. 
* $COPYRIGHT$ * * Additional copyrights may follow @@ -40,9 +41,9 @@ int mca_coll_demo_bcast_intra(void *buff, int count, { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo bcast_intra"); - return demo_module->underlying.coll_bcast(buff, count, datatype, - root, comm, - demo_module->underlying.coll_bcast_module); + return demo_module->c_coll.coll_bcast(buff, count, datatype, + root, comm, + demo_module->c_coll.coll_bcast_module); } @@ -60,7 +61,7 @@ int mca_coll_demo_bcast_inter(void *buff, int count, { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo bcast_inter"); - return demo_module->underlying.coll_bcast(buff, count, datatype, - root, comm, - demo_module->underlying.coll_bcast_module); + return demo_module->c_coll.coll_bcast(buff, count, datatype, + root, comm, + demo_module->c_coll.coll_bcast_module); } diff --git a/ompi/mca/coll/demo/coll_demo_component.c b/ompi/mca/coll/demo/coll_demo_component.c index ee23252b2e2..80f91da046a 100644 --- a/ompi/mca/coll/demo/coll_demo_component.c +++ b/ompi/mca/coll/demo/coll_demo_component.c @@ -13,6 +13,7 @@ * Copyright (c) 2008 Cisco Systems, Inc. All rights reserved. * Copyright (c) 2015 Los Alamos National Security, LLC. All rights * reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. 
* $COPYRIGHT$ * * Additional copyrights may follow @@ -105,39 +106,10 @@ static int demo_register(void) static void mca_coll_demo_module_construct(mca_coll_demo_module_t *module) { - memset(&module->underlying, 0, sizeof(mca_coll_base_comm_coll_t)); + memset(&module->c_coll, 0, sizeof(mca_coll_base_comm_coll_t)); } -#define RELEASE(module, func) \ - do { \ - if (NULL != module->underlying.coll_ ## func ## _module) { \ - OBJ_RELEASE(module->underlying.coll_ ## func ## _module); \ - } \ - } while (0) - -static void -mca_coll_demo_module_destruct(mca_coll_demo_module_t *module) -{ - RELEASE(module, allgather); - RELEASE(module, allgatherv); - RELEASE(module, allreduce); - RELEASE(module, alltoall); - RELEASE(module, alltoallv); - RELEASE(module, alltoallw); - RELEASE(module, barrier); - RELEASE(module, bcast); - RELEASE(module, exscan); - RELEASE(module, gather); - RELEASE(module, gatherv); - RELEASE(module, reduce); - RELEASE(module, reduce_scatter); - RELEASE(module, scan); - RELEASE(module, scatter); - RELEASE(module, scatterv); -} - - OBJ_CLASS_INSTANCE(mca_coll_demo_module_t, mca_coll_base_module_t, mca_coll_demo_module_construct, - mca_coll_demo_module_destruct); + NULL); diff --git a/ompi/mca/coll/demo/coll_demo_exscan.c b/ompi/mca/coll/demo/coll_demo_exscan.c index c970369d0dd..de9d469321c 100644 --- a/ompi/mca/coll/demo/coll_demo_exscan.c +++ b/ompi/mca/coll/demo/coll_demo_exscan.c @@ -9,6 +9,7 @@ * University of Stuttgart. All rights reserved. * Copyright (c) 2004-2005 The Regents of the University of California. * All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. 
* $COPYRIGHT$ * * Additional copyrights may follow @@ -33,7 +34,7 @@ * Accepts: - same arguments as MPI_Exscan() * Returns: - MPI_SUCCESS or error code */ -int mca_coll_demo_exscan_intra(void *sbuf, void *rbuf, int count, +int mca_coll_demo_exscan_intra(const void *sbuf, void *rbuf, int count, struct ompi_datatype_t *dtype, struct ompi_op_t *op, struct ompi_communicator_t *comm, @@ -41,9 +42,9 @@ int mca_coll_demo_exscan_intra(void *sbuf, void *rbuf, int count, { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo exscan_intra"); - return demo_module->underlying.coll_exscan(sbuf, rbuf, count, dtype, - op, comm, - demo_module->underlying.coll_exscan_module); + return demo_module->c_coll.coll_exscan(sbuf, rbuf, count, dtype, + op, comm, + demo_module->c_coll.coll_exscan_module); } /* diff --git a/ompi/mca/coll/demo/coll_demo_gather.c b/ompi/mca/coll/demo/coll_demo_gather.c index 9f9840acf8f..d12fb2ad3c3 100644 --- a/ompi/mca/coll/demo/coll_demo_gather.c +++ b/ompi/mca/coll/demo/coll_demo_gather.c @@ -9,6 +9,7 @@ * University of Stuttgart. All rights reserved. * Copyright (c) 2004-2005 The Regents of the University of California. * All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. 
* $COPYRIGHT$ * * Additional copyrights may follow @@ -32,7 +33,7 @@ * Accepts: - same arguments as MPI_Gather() * Returns: - MPI_SUCCESS or error code */ -int mca_coll_demo_gather_intra(void *sbuf, int scount, +int mca_coll_demo_gather_intra(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void *rbuf, int rcount, struct ompi_datatype_t *rdtype, @@ -41,10 +42,10 @@ int mca_coll_demo_gather_intra(void *sbuf, int scount, { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo gather_intra"); - return demo_module->underlying.coll_gather(sbuf, scount, sdtype, - rbuf, rcount, rdtype, - root, comm, - demo_module->underlying.coll_gather_module); + return demo_module->c_coll.coll_gather(sbuf, scount, sdtype, + rbuf, rcount, rdtype, + root, comm, + demo_module->c_coll.coll_gather_module); } @@ -55,7 +56,7 @@ int mca_coll_demo_gather_intra(void *sbuf, int scount, * Accepts: - same arguments as MPI_Gather() * Returns: - MPI_SUCCESS or error code */ -int mca_coll_demo_gather_inter(void *sbuf, int scount, +int mca_coll_demo_gather_inter(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void *rbuf, int rcount, struct ompi_datatype_t *rdtype, @@ -64,8 +65,8 @@ int mca_coll_demo_gather_inter(void *sbuf, int scount, { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo gather_inter"); - return demo_module->underlying.coll_gather(sbuf, scount, sdtype, - rbuf, rcount, rdtype, - root, comm, - demo_module->underlying.coll_gather_module); + return demo_module->c_coll.coll_gather(sbuf, scount, sdtype, + rbuf, rcount, rdtype, + root, comm, + demo_module->c_coll.coll_gather_module); } diff --git a/ompi/mca/coll/demo/coll_demo_gatherv.c b/ompi/mca/coll/demo/coll_demo_gatherv.c index f23b37a0d88..082e10153f2 100644 --- a/ompi/mca/coll/demo/coll_demo_gatherv.c +++ 
b/ompi/mca/coll/demo/coll_demo_gatherv.c @@ -9,6 +9,7 @@ * University of Stuttgart. All rights reserved. * Copyright (c) 2004-2005 The Regents of the University of California. * All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -33,19 +34,19 @@ * Accepts: - same arguments as MPI_Gatherv() * Returns: - MPI_SUCCESS or error code */ -int mca_coll_demo_gatherv_intra(void *sbuf, int scount, +int mca_coll_demo_gatherv_intra(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, - void *rbuf, int *rcounts, int *disps, + void *rbuf, const int *rcounts, const int *disps, struct ompi_datatype_t *rdtype, int root, struct ompi_communicator_t *comm, mca_coll_base_module_t *module) { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo gatherv_intra"); - return demo_module->underlying.coll_gatherv(sbuf, scount, sdtype, - rbuf, rcounts, disps, - rdtype, root, comm, - demo_module->underlying.coll_gatherv_module); + return demo_module->c_coll.coll_gatherv(sbuf, scount, sdtype, + rbuf, rcounts, disps, + rdtype, root, comm, + demo_module->c_coll.coll_gatherv_module); } @@ -56,17 +57,17 @@ int mca_coll_demo_gatherv_intra(void *sbuf, int scount, * Accepts: - same arguments as MPI_Gatherv() * Returns: - MPI_SUCCESS or error code */ -int mca_coll_demo_gatherv_inter(void *sbuf, int scount, +int mca_coll_demo_gatherv_inter(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, - void *rbuf, int *rcounts, int *disps, + void *rbuf, const int *rcounts, const int *disps, struct ompi_datatype_t *rdtype, int root, struct ompi_communicator_t *comm, mca_coll_base_module_t *module) { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo gatherv_inter"); - return demo_module->underlying.coll_gatherv(sbuf, scount, 
sdtype, - rbuf, rcounts, disps, - rdtype, root, comm, - demo_module->underlying.coll_gatherv_module); + return demo_module->c_coll.coll_gatherv(sbuf, scount, sdtype, + rbuf, rcounts, disps, + rdtype, root, comm, + demo_module->c_coll.coll_gatherv_module); } diff --git a/ompi/mca/coll/demo/coll_demo_module.c b/ompi/mca/coll/demo/coll_demo_module.c index 5c66c98059f..0201340d9d9 100644 --- a/ompi/mca/coll/demo/coll_demo_module.c +++ b/ompi/mca/coll/demo/coll_demo_module.c @@ -9,6 +9,7 @@ * University of Stuttgart. All rights reserved. * Copyright (c) 2004-2005 The Regents of the University of California. * All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -27,71 +28,6 @@ #include "ompi/mca/coll/base/base.h" #include "coll_demo.h" -#if 0 - -/* - * Linear set of collective algorithms - */ -static const mca_coll_base_module_1_0_0_t intra = { - - /* Initialization / finalization functions */ - - mca_coll_demo_module_init, - mca_coll_demo_module_finalize, - - /* Collective function pointers */ - - mca_coll_demo_allgather_intra, - mca_coll_demo_allgatherv_intra, - mca_coll_demo_allreduce_intra, - mca_coll_demo_alltoall_intra, - mca_coll_demo_alltoallv_intra, - mca_coll_demo_alltoallw_intra, - mca_coll_demo_barrier_intra, - mca_coll_demo_bcast_intra, - NULL, /* Leave exscan blank just to force basic to be used */ - mca_coll_demo_gather_intra, - mca_coll_demo_gatherv_intra, - mca_coll_demo_reduce_intra, - mca_coll_demo_reduce_scatter_intra, - mca_coll_demo_scan_intra, - mca_coll_demo_scatter_intra, - mca_coll_demo_scatterv_intra -}; - - -/* - * Linear set of collective algorithms for intercommunicators - */ -static const mca_coll_base_module_1_0_0_t inter = { - - /* Initialization / finalization functions */ - - mca_coll_demo_module_init, - mca_coll_demo_module_finalize, - - /* Collective function pointers */ - - mca_coll_demo_allgather_inter, - mca_coll_demo_allgatherv_inter, - 
mca_coll_demo_allreduce_inter, - mca_coll_demo_alltoall_inter, - mca_coll_demo_alltoallv_inter, - mca_coll_demo_alltoallw_inter, - mca_coll_demo_barrier_inter, - mca_coll_demo_bcast_inter, - mca_coll_demo_exscan_inter, - mca_coll_demo_gather_inter, - mca_coll_demo_gatherv_inter, - mca_coll_demo_reduce_inter, - mca_coll_demo_reduce_scatter_inter, - NULL, - mca_coll_demo_scatter_inter, - mca_coll_demo_scatterv_inter -}; - -#endif - /* * Initial query function that is invoked during MPI_INIT, allowing * this component to disqualify itself if it doesn't support the @@ -105,6 +41,13 @@ int mca_coll_demo_init_query(bool enable_progress_threads, return OMPI_SUCCESS; } +static int +mca_coll_demo_module_enable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm); +static int +mca_coll_demo_module_disable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm); + /* * Invoked when there's a new communicator that has been created. * Look at the communicator and decide which set of functions and @@ -121,54 +64,38 @@ mca_coll_demo_comm_query(struct ompi_communicator_t *comm, int *priority) *priority = mca_coll_demo_priority; demo_module->super.coll_module_enable = mca_coll_demo_module_enable; - - if (OMPI_COMM_IS_INTRA(comm)) { - demo_module->super.coll_allgather = mca_coll_demo_allgather_intra; - demo_module->super.coll_allgatherv = mca_coll_demo_allgatherv_intra; - demo_module->super.coll_allreduce = mca_coll_demo_allreduce_intra; - demo_module->super.coll_alltoall = mca_coll_demo_alltoall_intra; - demo_module->super.coll_alltoallv = mca_coll_demo_alltoallv_intra; - demo_module->super.coll_alltoallw = mca_coll_demo_alltoallw_intra; - demo_module->super.coll_barrier = mca_coll_demo_barrier_intra; - demo_module->super.coll_bcast = mca_coll_demo_bcast_intra; - demo_module->super.coll_exscan = mca_coll_demo_exscan_intra; - demo_module->super.coll_gather = mca_coll_demo_gather_intra; - demo_module->super.coll_gatherv = mca_coll_demo_gatherv_intra; - 
demo_module->super.coll_reduce = mca_coll_demo_reduce_intra; - demo_module->super.coll_reduce_scatter = mca_coll_demo_reduce_scatter_intra; - demo_module->super.coll_scan = mca_coll_demo_scan_intra; - demo_module->super.coll_scatter = mca_coll_demo_scatter_intra; - demo_module->super.coll_scatterv = mca_coll_demo_scatterv_intra; - } else { - demo_module->super.coll_allgather = mca_coll_demo_allgather_inter; - demo_module->super.coll_allgatherv = mca_coll_demo_allgatherv_inter; - demo_module->super.coll_allreduce = mca_coll_demo_allreduce_inter; - demo_module->super.coll_alltoall = mca_coll_demo_alltoall_inter; - demo_module->super.coll_alltoallv = mca_coll_demo_alltoallv_inter; - demo_module->super.coll_alltoallw = mca_coll_demo_alltoallw_inter; - demo_module->super.coll_barrier = mca_coll_demo_barrier_inter; - demo_module->super.coll_bcast = mca_coll_demo_bcast_inter; - demo_module->super.coll_exscan = NULL; - demo_module->super.coll_gather = mca_coll_demo_gather_inter; - demo_module->super.coll_gatherv = mca_coll_demo_gatherv_inter; - demo_module->super.coll_reduce = mca_coll_demo_reduce_inter; - demo_module->super.coll_reduce_scatter = mca_coll_demo_reduce_scatter_inter; - demo_module->super.coll_scan = NULL; - demo_module->super.coll_scatter = mca_coll_demo_scatter_inter; - demo_module->super.coll_scatterv = mca_coll_demo_scatterv_inter; - } + demo_module->super.coll_module_disable = mca_coll_demo_module_disable; return &(demo_module->super); } -#define COPY(comm, module, func) \ - do { \ - module->underlying.coll_ ## func = comm->c_coll->coll_ ## func; \ - module->underlying.coll_ ## func = comm->c_coll->coll_ ## func; \ - if (NULL != module->underlying.coll_ ## func ## _module) { \ - OBJ_RETAIN(module->underlying.coll_ ## func ## _module); \ - } \ - } while (0) +#define DEMO_INSTALL_COLL_API(__comm, __module, __api, __func) \ + do \ + { \ + if (__comm->c_coll->coll_##__api) \ + { \ + /* save the current selected collective */ \ + MCA_COLL_SAVE_API(__comm, 
__api, (__module)->c_coll.coll_##__api, (__module)->c_coll.coll_##__api##_module, "demo"); \ + /* install our own */ \ + MCA_COLL_INSTALL_API(__comm, __api, __func, &__module->super, "demo"); \ + } \ + else \ + { \ + opal_show_help("help-mca-coll-base.txt", "comm-select:missing collective", true, \ + "demo", #__api, ompi_process_info.nodename, \ + mca_coll_demo_priority); \ + } \ + } while (0) + +#define DEMO_UNINSTALL_COLL_API(__comm, __module, __api) \ + do \ + { \ + if (&(__module)->super == __comm->c_coll->coll_##__api##_module) \ + { \ + /* put back the original collective */ \ + MCA_COLL_INSTALL_API(__comm, __api, __module->c_coll.coll_##__api, __module->c_coll.coll_##__api##_module, "demo"); \ + } \ + } while (0) int mca_coll_demo_module_enable(mca_coll_base_module_t *module, @@ -180,23 +107,70 @@ mca_coll_demo_module_enable(mca_coll_base_module_t *module, printf("Hello! This is the \"demo\" coll component. I'll be your coll component\ntoday. Please tip your waitresses well.\n"); } - /* save the old pointers */ - COPY(comm, demo_module, allgather); - COPY(comm, demo_module, allgatherv); - COPY(comm, demo_module, allreduce); - COPY(comm, demo_module, alltoall); - COPY(comm, demo_module, alltoallv); - COPY(comm, demo_module, alltoallw); - COPY(comm, demo_module, barrier); - COPY(comm, demo_module, bcast); - COPY(comm, demo_module, exscan); - COPY(comm, demo_module, gather); - COPY(comm, demo_module, gatherv); - COPY(comm, demo_module, reduce); - COPY(comm, demo_module, reduce_scatter); - COPY(comm, demo_module, scan); - COPY(comm, demo_module, scatter); - COPY(comm, demo_module, scatterv); + if (OMPI_COMM_IS_INTRA(comm)) + { + DEMO_INSTALL_COLL_API(comm, demo_module, allgather, mca_coll_demo_allgather_intra); + DEMO_INSTALL_COLL_API(comm, demo_module, allgatherv, mca_coll_demo_allgatherv_intra); + DEMO_INSTALL_COLL_API(comm, demo_module, allreduce, mca_coll_demo_allreduce_intra); + DEMO_INSTALL_COLL_API(comm, demo_module, alltoall, 
mca_coll_demo_alltoall_intra); + DEMO_INSTALL_COLL_API(comm, demo_module, alltoallv, mca_coll_demo_alltoallv_intra); + DEMO_INSTALL_COLL_API(comm, demo_module, alltoallw, mca_coll_demo_alltoallw_intra); + DEMO_INSTALL_COLL_API(comm, demo_module, barrier, mca_coll_demo_barrier_intra); + DEMO_INSTALL_COLL_API(comm, demo_module, bcast, mca_coll_demo_bcast_intra); + DEMO_INSTALL_COLL_API(comm, demo_module, exscan, mca_coll_demo_exscan_intra); + DEMO_INSTALL_COLL_API(comm, demo_module, gather, mca_coll_demo_gather_intra); + DEMO_INSTALL_COLL_API(comm, demo_module, gatherv, mca_coll_demo_gatherv_intra); + DEMO_INSTALL_COLL_API(comm, demo_module, reduce, mca_coll_demo_reduce_intra); + DEMO_INSTALL_COLL_API(comm, demo_module, reduce_scatter, mca_coll_demo_reduce_scatter_intra); + DEMO_INSTALL_COLL_API(comm, demo_module, scan, mca_coll_demo_scan_intra); + DEMO_INSTALL_COLL_API(comm, demo_module, scatter, mca_coll_demo_scatter_intra); + DEMO_INSTALL_COLL_API(comm, demo_module, scatterv, mca_coll_demo_scatterv_intra); + } + else + { + DEMO_INSTALL_COLL_API(comm, demo_module, allgather, mca_coll_demo_allgather_inter); + DEMO_INSTALL_COLL_API(comm, demo_module, allgatherv, mca_coll_demo_allgatherv_inter); + DEMO_INSTALL_COLL_API(comm, demo_module, allreduce, mca_coll_demo_allreduce_inter); + DEMO_INSTALL_COLL_API(comm, demo_module, alltoall, mca_coll_demo_alltoall_inter); + DEMO_INSTALL_COLL_API(comm, demo_module, alltoallv, mca_coll_demo_alltoallv_inter); + DEMO_INSTALL_COLL_API(comm, demo_module, alltoallw, mca_coll_demo_alltoallw_inter); + DEMO_INSTALL_COLL_API(comm, demo_module, barrier, mca_coll_demo_barrier_inter); + DEMO_INSTALL_COLL_API(comm, demo_module, bcast, mca_coll_demo_bcast_inter); + /* Skip the exscan for inter-comms, it is not supported */ + DEMO_INSTALL_COLL_API(comm, demo_module, gather, mca_coll_demo_gather_inter); + DEMO_INSTALL_COLL_API(comm, demo_module, gatherv, mca_coll_demo_gatherv_inter); + DEMO_INSTALL_COLL_API(comm, demo_module, reduce, 
mca_coll_demo_reduce_inter); + DEMO_INSTALL_COLL_API(comm, demo_module, reduce_scatter, mca_coll_demo_reduce_scatter_inter); + /* Skip the scan for inter-comms, it is not supported */ + DEMO_INSTALL_COLL_API(comm, demo_module, scatter, mca_coll_demo_scatter_inter); + DEMO_INSTALL_COLL_API(comm, demo_module, scatterv, mca_coll_demo_scatterv_inter); + } + return OMPI_SUCCESS; +} + +static int +mca_coll_demo_module_disable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm) +{ + mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; + + /* put back the old pointers */ + DEMO_UNINSTALL_COLL_API(comm, demo_module, allgather); + DEMO_UNINSTALL_COLL_API(comm, demo_module, allgatherv); + DEMO_UNINSTALL_COLL_API(comm, demo_module, allreduce); + DEMO_UNINSTALL_COLL_API(comm, demo_module, alltoall); + DEMO_UNINSTALL_COLL_API(comm, demo_module, alltoallv); + DEMO_UNINSTALL_COLL_API(comm, demo_module, alltoallw); + DEMO_UNINSTALL_COLL_API(comm, demo_module, barrier); + DEMO_UNINSTALL_COLL_API(comm, demo_module, bcast); + DEMO_UNINSTALL_COLL_API(comm, demo_module, exscan); + DEMO_UNINSTALL_COLL_API(comm, demo_module, gather); + DEMO_UNINSTALL_COLL_API(comm, demo_module, gatherv); + DEMO_UNINSTALL_COLL_API(comm, demo_module, reduce); + DEMO_UNINSTALL_COLL_API(comm, demo_module, reduce_scatter); + DEMO_UNINSTALL_COLL_API(comm, demo_module, scan); + DEMO_UNINSTALL_COLL_API(comm, demo_module, scatter); + DEMO_UNINSTALL_COLL_API(comm, demo_module, scatterv); return OMPI_SUCCESS; } diff --git a/ompi/mca/coll/demo/coll_demo_reduce.c b/ompi/mca/coll/demo/coll_demo_reduce.c index 6df413902b6..03875148507 100644 --- a/ompi/mca/coll/demo/coll_demo_reduce.c +++ b/ompi/mca/coll/demo/coll_demo_reduce.c @@ -9,6 +9,7 @@ * University of Stuttgart. All rights reserved. * Copyright (c) 2004-2005 The Regents of the University of California. * All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. 
* $COPYRIGHT$ * * Additional copyrights may follow @@ -33,7 +34,7 @@ * Accepts: - same as MPI_Reduce() * Returns: - MPI_SUCCESS or error code */ -int mca_coll_demo_reduce_intra(void *sbuf, void *rbuf, int count, +int mca_coll_demo_reduce_intra(const void *sbuf, void *rbuf, int count, struct ompi_datatype_t *dtype, struct ompi_op_t *op, int root, struct ompi_communicator_t *comm, @@ -41,9 +42,9 @@ int mca_coll_demo_reduce_intra(void *sbuf, void *rbuf, int count, { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo reduce_intra"); - return demo_module->underlying.coll_reduce(sbuf, rbuf, count, dtype, - op, root, comm, - demo_module->underlying.coll_reduce_module); + return demo_module->c_coll.coll_reduce(sbuf, rbuf, count, dtype, + op, root, comm, + demo_module->c_coll.coll_reduce_module); } @@ -54,7 +55,7 @@ int mca_coll_demo_reduce_intra(void *sbuf, void *rbuf, int count, * Accepts: - same as MPI_Reduce() * Returns: - MPI_SUCCESS or error code */ -int mca_coll_demo_reduce_inter(void *sbuf, void *rbuf, int count, +int mca_coll_demo_reduce_inter(const void *sbuf, void *rbuf, int count, struct ompi_datatype_t *dtype, struct ompi_op_t *op, int root, struct ompi_communicator_t *comm, @@ -62,7 +63,7 @@ int mca_coll_demo_reduce_inter(void *sbuf, void *rbuf, int count, { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo reduce_inter"); - return demo_module->underlying.coll_reduce(sbuf, rbuf, count, dtype, - op, root, comm, - demo_module->underlying.coll_reduce_module); + return demo_module->c_coll.coll_reduce(sbuf, rbuf, count, dtype, + op, root, comm, + demo_module->c_coll.coll_reduce_module); } diff --git a/ompi/mca/coll/demo/coll_demo_reduce_scatter.c b/ompi/mca/coll/demo/coll_demo_reduce_scatter.c index 438f1008b3a..e77babd73c1 100644 --- 
a/ompi/mca/coll/demo/coll_demo_reduce_scatter.c +++ b/ompi/mca/coll/demo/coll_demo_reduce_scatter.c @@ -9,6 +9,7 @@ * University of Stuttgart. All rights reserved. * Copyright (c) 2004-2005 The Regents of the University of California. * All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -33,7 +34,7 @@ * Accepts: - same as MPI_Reduce_scatter() * Returns: - MPI_SUCCESS or error code */ -int mca_coll_demo_reduce_scatter_intra(void *sbuf, void *rbuf, int *rcounts, +int mca_coll_demo_reduce_scatter_intra(const void *sbuf, void *rbuf, const int *rcounts, struct ompi_datatype_t *dtype, struct ompi_op_t *op, struct ompi_communicator_t *comm, @@ -41,9 +42,9 @@ int mca_coll_demo_reduce_scatter_intra(void *sbuf, void *rbuf, int *rcounts, { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo scatter_intra"); - return demo_module->underlying.coll_reduce_scatter(sbuf, rbuf, rcounts, - dtype, op, comm, - demo_module->underlying.coll_reduce_scatter_module); + return demo_module->c_coll.coll_reduce_scatter(sbuf, rbuf, rcounts, + dtype, op, comm, + demo_module->c_coll.coll_reduce_scatter_module); } @@ -54,7 +55,7 @@ int mca_coll_demo_reduce_scatter_intra(void *sbuf, void *rbuf, int *rcounts, * Accepts: - same arguments as MPI_Reduce_scatter() * Returns: - MPI_SUCCESS or error code */ -int mca_coll_demo_reduce_scatter_inter(void *sbuf, void *rbuf, int *rcounts, +int mca_coll_demo_reduce_scatter_inter(const void *sbuf, void *rbuf, const int *rcounts, struct ompi_datatype_t *dtype, struct ompi_op_t *op, struct ompi_communicator_t *comm, @@ -62,7 +63,7 @@ int mca_coll_demo_reduce_scatter_inter(void *sbuf, void *rbuf, int *rcounts, { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo scatter_inter"); - return 
demo_module->underlying.coll_reduce_scatter(sbuf, rbuf, rcounts, - dtype, op, comm, - demo_module->underlying.coll_reduce_scatter_module); + return demo_module->c_coll.coll_reduce_scatter(sbuf, rbuf, rcounts, + dtype, op, comm, + demo_module->c_coll.coll_reduce_scatter_module); } diff --git a/ompi/mca/coll/demo/coll_demo_scan.c b/ompi/mca/coll/demo/coll_demo_scan.c index 90d3cb343b1..614d30e4ef8 100644 --- a/ompi/mca/coll/demo/coll_demo_scan.c +++ b/ompi/mca/coll/demo/coll_demo_scan.c @@ -9,6 +9,7 @@ * University of Stuttgart. All rights reserved. * Copyright (c) 2004-2005 The Regents of the University of California. * All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -33,7 +34,7 @@ * Accepts: - same arguments as MPI_Scan() * Returns: - MPI_SUCCESS or error code */ -int mca_coll_demo_scan_intra(void *sbuf, void *rbuf, int count, +int mca_coll_demo_scan_intra(const void *sbuf, void *rbuf, int count, struct ompi_datatype_t *dtype, struct ompi_op_t *op, struct ompi_communicator_t *comm, @@ -41,9 +42,9 @@ int mca_coll_demo_scan_intra(void *sbuf, void *rbuf, int count, { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo scan_intra"); - return demo_module->underlying.coll_scan(sbuf, rbuf, count, - dtype, op, comm, - demo_module->underlying.coll_scan_module); + return demo_module->c_coll.coll_scan(sbuf, rbuf, count, + dtype, op, comm, + demo_module->c_coll.coll_scan_module); } diff --git a/ompi/mca/coll/demo/coll_demo_scatter.c b/ompi/mca/coll/demo/coll_demo_scatter.c index ccc2e401df6..f012e0dfb5f 100644 --- a/ompi/mca/coll/demo/coll_demo_scatter.c +++ b/ompi/mca/coll/demo/coll_demo_scatter.c @@ -9,6 +9,7 @@ * University of Stuttgart. All rights reserved. * Copyright (c) 2004-2005 The Regents of the University of California. * All rights reserved. 
+ * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -33,7 +34,7 @@ * Accepts: - same arguments as MPI_Scatter() * Returns: - MPI_SUCCESS or error code */ -int mca_coll_demo_scatter_intra(void *sbuf, int scount, +int mca_coll_demo_scatter_intra(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void *rbuf, int rcount, struct ompi_datatype_t *rdtype, @@ -43,10 +44,10 @@ int mca_coll_demo_scatter_intra(void *sbuf, int scount, { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo scatter_intra"); - return demo_module->underlying.coll_scatter(sbuf, scount, sdtype, - rbuf, rcount, rdtype, - root, comm, - demo_module->underlying.coll_scatter_module); + return demo_module->c_coll.coll_scatter(sbuf, scount, sdtype, + rbuf, rcount, rdtype, + root, comm, + demo_module->c_coll.coll_scatter_module); } @@ -57,7 +58,7 @@ int mca_coll_demo_scatter_intra(void *sbuf, int scount, * Accepts: - same arguments as MPI_Scatter() * Returns: - MPI_SUCCESS or error code */ -int mca_coll_demo_scatter_inter(void *sbuf, int scount, +int mca_coll_demo_scatter_inter(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void *rbuf, int rcount, struct ompi_datatype_t *rdtype, @@ -67,8 +68,8 @@ int mca_coll_demo_scatter_inter(void *sbuf, int scount, { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo scatter_inter"); - return demo_module->underlying.coll_scatter(sbuf, scount, sdtype, - rbuf, rcount, rdtype, - root, comm, - demo_module->underlying.coll_scatter_module); + return demo_module->c_coll.coll_scatter(sbuf, scount, sdtype, + rbuf, rcount, rdtype, + root, comm, + demo_module->c_coll.coll_scatter_module); } diff --git a/ompi/mca/coll/demo/coll_demo_scatterv.c b/ompi/mca/coll/demo/coll_demo_scatterv.c index 
3084efc0de5..8e2101a8f36 100644 --- a/ompi/mca/coll/demo/coll_demo_scatterv.c +++ b/ompi/mca/coll/demo/coll_demo_scatterv.c @@ -9,6 +9,7 @@ * University of Stuttgart. All rights reserved. * Copyright (c) 2004-2005 The Regents of the University of California. * All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -33,8 +34,8 @@ * Accepts: - same arguments as MPI_Scatterv() * Returns: - MPI_SUCCESS or error code */ -int mca_coll_demo_scatterv_intra(void *sbuf, int *scounts, - int *disps, struct ompi_datatype_t *sdtype, +int mca_coll_demo_scatterv_intra(const void *sbuf, const int *scounts, + const int *disps, struct ompi_datatype_t *sdtype, void *rbuf, int rcount, struct ompi_datatype_t *rdtype, int root, struct ompi_communicator_t *comm, @@ -42,10 +43,10 @@ int mca_coll_demo_scatterv_intra(void *sbuf, int *scounts, { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo scatterv_intra"); - return demo_module->underlying.coll_scatterv(sbuf, scounts, disps, - sdtype, rbuf, rcount, - rdtype, root, comm, - demo_module->underlying.coll_scatterv_module); + return demo_module->c_coll.coll_scatterv(sbuf, scounts, disps, + sdtype, rbuf, rcount, + rdtype, root, comm, + demo_module->c_coll.coll_scatterv_module); } @@ -56,8 +57,8 @@ int mca_coll_demo_scatterv_intra(void *sbuf, int *scounts, * Accepts: - same arguments as MPI_Scatterv() * Returns: - MPI_SUCCESS or error code */ -int mca_coll_demo_scatterv_inter(void *sbuf, int *scounts, - int *disps, struct ompi_datatype_t *sdtype, +int mca_coll_demo_scatterv_inter(const void *sbuf, const int *scounts, + const int *disps, struct ompi_datatype_t *sdtype, void *rbuf, int rcount, struct ompi_datatype_t *rdtype, int root, struct ompi_communicator_t *comm, @@ -65,8 +66,8 @@ int mca_coll_demo_scatterv_inter(void *sbuf, int *scounts, { mca_coll_demo_module_t 
*demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo scatterv_inter"); - return demo_module->underlying.coll_scatterv(sbuf, scounts, disps, - sdtype, rbuf, rcount, - rdtype, root, comm, - demo_module->underlying.coll_scatterv_module); + return demo_module->c_coll.coll_scatterv(sbuf, scounts, disps, + sdtype, rbuf, rcount, + rdtype, root, comm, + demo_module->c_coll.coll_scatterv_module); } diff --git a/ompi/mca/coll/ftagree/coll_ftagree.h b/ompi/mca/coll/ftagree/coll_ftagree.h index 86e0d4da314..9ca5b8b9b1f 100644 --- a/ompi/mca/coll/ftagree/coll_ftagree.h +++ b/ompi/mca/coll/ftagree/coll_ftagree.h @@ -3,6 +3,7 @@ * Copyright (c) 2012-2020 The University of Tennessee and The University * of Tennessee Research Foundation. All rights * reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -98,9 +99,6 @@ mca_coll_base_module_t *mca_coll_ftagree_comm_query(struct ompi_communicator_t *comm, int *priority); -int mca_coll_ftagree_module_enable(mca_coll_base_module_t *module, - struct ompi_communicator_t *comm); - /* * Agreement algorithms */ diff --git a/ompi/mca/coll/ftagree/coll_ftagree_module.c b/ompi/mca/coll/ftagree/coll_ftagree_module.c index 6947139a7f5..755fca1b9bc 100644 --- a/ompi/mca/coll/ftagree/coll_ftagree_module.c +++ b/ompi/mca/coll/ftagree/coll_ftagree_module.c @@ -2,6 +2,7 @@ * Copyright (c) 2012-2020 The University of Tennessee and The University * of Tennessee Research Foundation. All rights * reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. 
* $COPYRIGHT$ * * Additional copyrights may follow @@ -33,6 +34,30 @@ mca_coll_ftagree_init_query(bool enable_progress_threads, return OMPI_SUCCESS; } +/* + * Init/Fini module on the communicator + */ +static int +mca_coll_ftagree_module_enable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm) +{ + MCA_COLL_INSTALL_API(comm, agree, module->coll_agree, module, "ftagree"); + MCA_COLL_INSTALL_API(comm, iagree, module->coll_iagree, module, "ftagree"); + + /* All done */ + return OMPI_SUCCESS; +} + +static int +mca_coll_ftagree_module_disable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm) +{ + MCA_COLL_INSTALL_API(comm, agree, NULL, NULL, "ftagree"); + MCA_COLL_INSTALL_API(comm, iagree, NULL, NULL, "ftagree"); + + /* All done */ + return OMPI_SUCCESS; +} /* * Invoked when there's a new communicator that has been created. @@ -76,6 +101,7 @@ mca_coll_ftagree_comm_query(struct ompi_communicator_t *comm, * algorithms. */ ftagree_module->super.coll_module_enable = mca_coll_ftagree_module_enable; + ftagree_module->super.coll_module_disable = mca_coll_ftagree_module_disable; /* This component does not provide any base collectives, * just the FT collectives. @@ -112,15 +138,3 @@ mca_coll_ftagree_comm_query(struct ompi_communicator_t *comm, return &(ftagree_module->super); } - -/* - * Init module on the communicator - */ -int -mca_coll_ftagree_module_enable(mca_coll_base_module_t *module, - struct ompi_communicator_t *comm) -{ - /* All done */ - return OMPI_SUCCESS; -} - diff --git a/ompi/mca/coll/han/coll_han.h b/ompi/mca/coll/han/coll_han.h index 29047abb317..de4018bec22 100644 --- a/ompi/mca/coll/han/coll_han.h +++ b/ompi/mca/coll/han/coll_han.h @@ -6,6 +6,7 @@ * Copyright (c) 2020-2022 Bull S.A.S. All rights reserved. * Copyright (c) Amazon.com, Inc. or its affiliates. * All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. 
* $COPYRIGHT$ * * Additional copyrights may follow @@ -276,13 +277,14 @@ typedef struct mca_coll_han_component_t { int max_dynamic_errors; } mca_coll_han_component_t; - /* * Structure used to store what is necessary for the collective operations * routines in case of fallback. */ -typedef struct mca_coll_han_single_collective_fallback_s { - union { +typedef struct mca_coll_han_single_collective_fallback_s +{ + union + { mca_coll_base_module_allgather_fn_t allgather; mca_coll_base_module_allgatherv_fn_t allgatherv; mca_coll_base_module_allreduce_fn_t allreduce; @@ -293,7 +295,7 @@ typedef struct mca_coll_han_single_collective_fallback_s { mca_coll_base_module_reduce_fn_t reduce; mca_coll_base_module_scatter_fn_t scatter; mca_coll_base_module_scatterv_fn_t scatterv; - } module_fn; + }; mca_coll_base_module_t* module; } mca_coll_han_single_collective_fallback_t; @@ -302,7 +304,8 @@ typedef struct mca_coll_han_single_collective_fallback_s { * by HAN. This structure is used as a fallback during subcommunicator * creation. 
*/ -typedef struct mca_coll_han_collectives_fallback_s { +typedef struct mca_coll_han_collectives_fallback_s +{ mca_coll_han_single_collective_fallback_t allgather; mca_coll_han_single_collective_fallback_t allgatherv; mca_coll_han_single_collective_fallback_t allreduce; @@ -334,7 +337,7 @@ typedef struct mca_coll_han_module_t { bool is_heterogeneous; /* To be able to fallback when the cases are not supported */ - struct mca_coll_han_collectives_fallback_s fallback; + mca_coll_han_collectives_fallback_t fallback; /* To be able to fallback on reproducible algorithm */ mca_coll_base_module_reduce_fn_t reproducible_reduce; @@ -365,61 +368,61 @@ OBJ_CLASS_DECLARATION(mca_coll_han_module_t); * Some defines to stick to the naming used in the other components in terms of * fallback routines */ -#define previous_allgather fallback.allgather.module_fn.allgather +#define previous_allgather fallback.allgather.allgather #define previous_allgather_module fallback.allgather.module -#define previous_allgatherv fallback.allgatherv.module_fn.allgatherv +#define previous_allgatherv fallback.allgatherv.allgatherv #define previous_allgatherv_module fallback.allgatherv.module -#define previous_allreduce fallback.allreduce.module_fn.allreduce +#define previous_allreduce fallback.allreduce.allreduce #define previous_allreduce_module fallback.allreduce.module -#define previous_barrier fallback.barrier.module_fn.barrier +#define previous_barrier fallback.barrier.barrier #define previous_barrier_module fallback.barrier.module -#define previous_bcast fallback.bcast.module_fn.bcast +#define previous_bcast fallback.bcast.bcast #define previous_bcast_module fallback.bcast.module -#define previous_reduce fallback.reduce.module_fn.reduce +#define previous_reduce fallback.reduce.reduce #define previous_reduce_module fallback.reduce.module -#define previous_gather fallback.gather.module_fn.gather +#define previous_gather fallback.gather.gather #define previous_gather_module fallback.gather.module 
-#define previous_gatherv fallback.gatherv.module_fn.gatherv +#define previous_gatherv fallback.gatherv.gatherv #define previous_gatherv_module fallback.gatherv.module -#define previous_scatter fallback.scatter.module_fn.scatter +#define previous_scatter fallback.scatter.scatter #define previous_scatter_module fallback.scatter.module -#define previous_scatterv fallback.scatterv.module_fn.scatterv +#define previous_scatterv fallback.scatterv.scatterv #define previous_scatterv_module fallback.scatterv.module /* macro to correctly load a fallback collective module */ -#define HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, COLL) \ - do { \ - if ( ((COMM)->c_coll->coll_ ## COLL ## _module) == (mca_coll_base_module_t*)(HANM) ) { \ - (COMM)->c_coll->coll_ ## COLL = (HANM)->previous_## COLL; \ - mca_coll_base_module_t *coll_module = (COMM)->c_coll->coll_ ## COLL ## _module; \ - (COMM)->c_coll->coll_ ## COLL ## _module = (HANM)->previous_ ## COLL ## _module; \ - OBJ_RETAIN((COMM)->c_coll->coll_ ## COLL ## _module); \ - OBJ_RELEASE(coll_module); \ - } \ - } while(0) - -/* macro to correctly load /all/ fallback collectives */ -#define HAN_LOAD_FALLBACK_COLLECTIVES(HANM, COMM) \ - do { \ - HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, barrier); \ - HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, bcast); \ - HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, scatter); \ - HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, scatterv); \ - HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, gather); \ - HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, gatherv); \ - HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, reduce); \ - HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, allreduce); \ - HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, allgather); \ - HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, allgatherv); \ +#define HAN_UNINSTALL_COLL_API(__comm, __module, __api) \ + do \ + { \ + if (__comm->c_coll->coll_##__api##_module == &__module->super) \ + { \ + MCA_COLL_INSTALL_API(__comm, __api, __module->previous_##__api, \ + __module->previous_##__api##_module, 
"han"); \ + /* Do not reset the fallback to NULL it will be needed */ \ + } \ + } while (0) + + /* macro to correctly load /all/ fallback collectives */ +#define HAN_LOAD_FALLBACK_COLLECTIVES(COMM, HANM) \ + do { \ + HAN_UNINSTALL_COLL_API(COMM, HANM, barrier); \ + HAN_UNINSTALL_COLL_API(COMM, HANM, bcast); \ + HAN_UNINSTALL_COLL_API(COMM, HANM, scatter); \ + HAN_UNINSTALL_COLL_API(COMM, HANM, scatterv); \ + HAN_UNINSTALL_COLL_API(COMM, HANM, gather); \ + HAN_UNINSTALL_COLL_API(COMM, HANM, gatherv); \ + HAN_UNINSTALL_COLL_API(COMM, HANM, reduce); \ + HAN_UNINSTALL_COLL_API(COMM, HANM, allreduce); \ + HAN_UNINSTALL_COLL_API(COMM, HANM, allgather); \ + HAN_UNINSTALL_COLL_API(COMM, HANM, allgatherv); \ han_module->enabled = false; /* entire module set to pass-through from now on */ \ } while(0) diff --git a/ompi/mca/coll/han/coll_han_allgather.c b/ompi/mca/coll/han/coll_han_allgather.c index fa827dccf09..8c57aeffb5b 100644 --- a/ompi/mca/coll/han/coll_han_allgather.c +++ b/ompi/mca/coll/han/coll_han_allgather.c @@ -4,6 +4,7 @@ * reserved. * Copyright (c) 2020 Bull S.A.S. All rights reserved. * Copyright (c) 2022 IBM Corporation. All rights reserved + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -82,7 +83,7 @@ mca_coll_han_allgather_intra(const void *sbuf, int scount, OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output, "han cannot handle allgather within this communicator. 
Fall back on another component\n")); /* HAN cannot work with this communicator so fallback on all collectives */ - HAN_LOAD_FALLBACK_COLLECTIVES(han_module, comm); + HAN_LOAD_FALLBACK_COLLECTIVES(comm, han_module); return han_module->previous_allgather(sbuf, scount, sdtype, rbuf, rcount, rdtype, comm, han_module->previous_allgather_module); } @@ -97,7 +98,7 @@ mca_coll_han_allgather_intra(const void *sbuf, int scount, if (han_module->are_ppn_imbalanced) { OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output, "han cannot handle allgather with this communicator (imbalance). Fall back on another component\n")); - HAN_LOAD_FALLBACK_COLLECTIVE(han_module, comm, allgather); + HAN_UNINSTALL_COLL_API(comm, han_module, allgather); return han_module->previous_allgather(sbuf, scount, sdtype, rbuf, rcount, rdtype, comm, han_module->previous_allgather_module); } @@ -306,7 +307,7 @@ mca_coll_han_allgather_intra_simple(const void *sbuf, int scount, OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output, "han cannot handle allgather within this communicator. Fall back on another component\n")); /* HAN cannot work with this communicator so fallback on all collectives */ - HAN_LOAD_FALLBACK_COLLECTIVES(han_module, comm); + HAN_LOAD_FALLBACK_COLLECTIVES(comm, han_module); return han_module->previous_allgather(sbuf, scount, sdtype, rbuf, rcount, rdtype, comm, han_module->previous_allgather_module); } @@ -320,7 +321,7 @@ mca_coll_han_allgather_intra_simple(const void *sbuf, int scount, /* Put back the fallback collective support and call it once. All * future calls will then be automatically redirected. 
*/ - HAN_LOAD_FALLBACK_COLLECTIVE(han_module, comm, allgather); + HAN_UNINSTALL_COLL_API(comm, han_module, allgather); return han_module->previous_allgather(sbuf, scount, sdtype, rbuf, rcount, rdtype, comm, han_module->previous_allgather_module); } diff --git a/ompi/mca/coll/han/coll_han_allreduce.c b/ompi/mca/coll/han/coll_han_allreduce.c index 039913d7fdb..81212145d41 100644 --- a/ompi/mca/coll/han/coll_han_allreduce.c +++ b/ompi/mca/coll/han/coll_han_allreduce.c @@ -6,6 +6,7 @@ * * Copyright (c) 2020 Cisco Systems, Inc. All rights reserved. * Copyright (c) 2022 IBM Corporation. All rights reserved + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -109,7 +110,7 @@ mca_coll_han_allreduce_intra(const void *sbuf, OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output, "han cannot handle allreduce with this communicator. Drop HAN support in this communicator and fall back on another component\n")); /* HAN cannot work with this communicator so fallback on all collectives */ - HAN_LOAD_FALLBACK_COLLECTIVES(han_module, comm); + HAN_LOAD_FALLBACK_COLLECTIVES(comm, han_module); return han_module->previous_allreduce(sbuf, rbuf, count, dtype, op, comm, han_module->previous_allreduce_module); } @@ -495,7 +496,7 @@ mca_coll_han_allreduce_intra_simple(const void *sbuf, OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output, "han cannot handle allreduce with this communicator. 
Drop HAN support in this communicator and fall back on another component\n")); /* HAN cannot work with this communicator so fallback on all collectives */ - HAN_LOAD_FALLBACK_COLLECTIVES(han_module, comm); + HAN_LOAD_FALLBACK_COLLECTIVES(comm, han_module); return han_module->previous_allreduce(sbuf, rbuf, count, dtype, op, comm, han_module->previous_allreduce_module); } diff --git a/ompi/mca/coll/han/coll_han_barrier.c b/ompi/mca/coll/han/coll_han_barrier.c index 6f437135ced..19b0e38c9a7 100644 --- a/ompi/mca/coll/han/coll_han_barrier.c +++ b/ompi/mca/coll/han/coll_han_barrier.c @@ -3,6 +3,7 @@ * of Tennessee Research Foundation. All rights * reserved. * Copyright (c) 2020 Bull S.A.S. All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -36,10 +37,8 @@ mca_coll_han_barrier_intra_simple(struct ompi_communicator_t *comm, if( OMPI_SUCCESS != mca_coll_han_comm_create_new(comm, han_module) ) { OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output, "han cannot handle barrier with this communicator. Fall back on another component\n")); - /* Put back the fallback collective support and call it once. All - * future calls will then be automatically redirected. - */ - HAN_LOAD_FALLBACK_COLLECTIVES(han_module, comm); + /* Try to put back the fallback collective support and call it once. */ + HAN_LOAD_FALLBACK_COLLECTIVES(comm, han_module); return han_module->previous_barrier(comm, han_module->previous_barrier_module); } diff --git a/ompi/mca/coll/han/coll_han_bcast.c b/ompi/mca/coll/han/coll_han_bcast.c index 33a3cc8af44..b543c10db38 100644 --- a/ompi/mca/coll/han/coll_han_bcast.c +++ b/ompi/mca/coll/han/coll_han_bcast.c @@ -4,6 +4,7 @@ * reserved. * Copyright (c) 2020 Bull S.A.S. All rights reserved. * Copyright (c) 2020 Cisco Systems, Inc. All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. 
* $COPYRIGHT$ * * Additional copyrights may follow @@ -83,7 +84,7 @@ mca_coll_han_bcast_intra(void *buf, /* Put back the fallback collective support and call it once. All * future calls will then be automatically redirected. */ - HAN_LOAD_FALLBACK_COLLECTIVES(han_module, comm); + HAN_LOAD_FALLBACK_COLLECTIVES(comm, han_module); return han_module->previous_bcast(buf, count, dtype, root, comm, han_module->previous_bcast_module); } @@ -96,7 +97,7 @@ mca_coll_han_bcast_intra(void *buf, /* Put back the fallback collective support and call it once. All * future calls will then be automatically redirected. */ - HAN_LOAD_FALLBACK_COLLECTIVE(han_module, comm, bcast); + HAN_UNINSTALL_COLL_API(comm, han_module, bcast); return han_module->previous_bcast(buf, count, dtype, root, comm, han_module->previous_bcast_module); } @@ -245,7 +246,7 @@ mca_coll_han_bcast_intra_simple(void *buf, /* Put back the fallback collective support and call it once. All * future calls will then be automatically redirected. */ - HAN_LOAD_FALLBACK_COLLECTIVES(han_module, comm); + HAN_LOAD_FALLBACK_COLLECTIVES(comm, han_module); return han_module->previous_bcast(buf, count, dtype, root, comm, han_module->previous_bcast_module); } @@ -258,7 +259,7 @@ mca_coll_han_bcast_intra_simple(void *buf, /* Put back the fallback collective support and call it once. All * future calls will then be automatically redirected. */ - HAN_LOAD_FALLBACK_COLLECTIVE(han_module, comm, bcast); + HAN_UNINSTALL_COLL_API(comm, han_module, bcast); return han_module->previous_bcast(buf, count, dtype, root, comm, han_module->previous_bcast_module); } diff --git a/ompi/mca/coll/han/coll_han_gather.c b/ompi/mca/coll/han/coll_han_gather.c index 8aaeddb8d26..256e4e714f0 100644 --- a/ompi/mca/coll/han/coll_han_gather.c +++ b/ompi/mca/coll/han/coll_han_gather.c @@ -5,6 +5,7 @@ * Copyright (c) 2020 Bull S.A.S. All rights reserved. * Copyright (c) 2020 Cisco Systems, Inc. All rights reserved. * Copyright (c) 2022 IBM Corporation. 
All rights reserved + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -92,7 +93,7 @@ mca_coll_han_gather_intra(const void *sbuf, int scount, OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output, "han cannot handle gather with this communicator. Fall back on another component\n")); /* HAN cannot work with this communicator so fallback on all collectives */ - HAN_LOAD_FALLBACK_COLLECTIVES(han_module, comm); + HAN_LOAD_FALLBACK_COLLECTIVES(comm, han_module); return han_module->previous_gather(sbuf, scount, sdtype, rbuf, rcount, rdtype, root, comm, han_module->previous_gather_module); } @@ -103,10 +104,8 @@ mca_coll_han_gather_intra(const void *sbuf, int scount, if (han_module->are_ppn_imbalanced) { OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output, "han cannot handle gather with this communicator (imbalance). Fall back on another component\n")); - /* Put back the fallback collective support and call it once. All - * future calls will then be automatically redirected. - */ - HAN_LOAD_FALLBACK_COLLECTIVE(han_module, comm, gather); + /* Unregister HAN gather if possible, and execute the fallback gather */ + HAN_UNINSTALL_COLL_API(comm, han_module, gather); return han_module->previous_gather(sbuf, scount, sdtype, rbuf, rcount, rdtype, root, comm, han_module->previous_gather_module); } @@ -319,7 +318,7 @@ mca_coll_han_gather_intra_simple(const void *sbuf, int scount, OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output, "han cannot handle gather with this communicator. 
Fall back on another component\n")); /* HAN cannot work with this communicator so fallback on all collectives */ - HAN_LOAD_FALLBACK_COLLECTIVES(han_module, comm); + HAN_LOAD_FALLBACK_COLLECTIVES(comm, han_module); return han_module->previous_gather(sbuf, scount, sdtype, rbuf, rcount, rdtype, root, comm, han_module->previous_gather_module); } @@ -330,10 +329,8 @@ mca_coll_han_gather_intra_simple(const void *sbuf, int scount, if (han_module->are_ppn_imbalanced){ OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output, "han cannot handle gather with this communicator (imbalance). Fall back on another component\n")); - /* Put back the fallback collective support and call it once. All - * future calls will then be automatically redirected. - */ - HAN_LOAD_FALLBACK_COLLECTIVE(han_module, comm, gather); + /* Unregister HAN gather if possible, and execute the fallback gather */ + HAN_UNINSTALL_COLL_API(comm, han_module, gather); return han_module->previous_gather(sbuf, scount, sdtype, rbuf, rcount, rdtype, root, comm, han_module->previous_gather_module); } diff --git a/ompi/mca/coll/han/coll_han_gatherv.c b/ompi/mca/coll/han/coll_han_gatherv.c index 006ab167eba..a80d93d0e92 100644 --- a/ompi/mca/coll/han/coll_han_gatherv.c +++ b/ompi/mca/coll/han/coll_han_gatherv.c @@ -7,6 +7,7 @@ * Copyright (c) 2022 IBM Corporation. All rights reserved * Copyright (c) Amazon.com, Inc. or its affiliates. * All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -76,7 +77,7 @@ int mca_coll_han_gatherv_intra(const void *sbuf, int scount, struct ompi_datatyp (30, mca_coll_han_component.han_output, "han cannot handle gatherv with this communicator. 
Fall back on another component\n")); /* HAN cannot work with this communicator so fallback on all collectives */ - HAN_LOAD_FALLBACK_COLLECTIVES(han_module, comm); + HAN_LOAD_FALLBACK_COLLECTIVES(comm, han_module); return han_module->previous_gatherv(sbuf, scount, sdtype, rbuf, rcounts, displs, rdtype, root, comm, han_module->previous_gatherv_module); } @@ -91,7 +92,7 @@ int mca_coll_han_gatherv_intra(const void *sbuf, int scount, struct ompi_datatyp /* Put back the fallback collective support and call it once. All * future calls will then be automatically redirected. */ - HAN_LOAD_FALLBACK_COLLECTIVE(han_module, comm, gatherv); + HAN_UNINSTALL_COLL_API(comm, han_module, gatherv); return han_module->previous_gatherv(sbuf, scount, sdtype, rbuf, rcounts, displs, rdtype, root, comm, han_module->previous_gatherv_module); } diff --git a/ompi/mca/coll/han/coll_han_module.c b/ompi/mca/coll/han/coll_han_module.c index 31ee2d3fb84..8be8d5a9319 100644 --- a/ompi/mca/coll/han/coll_han_module.c +++ b/ompi/mca/coll/han/coll_han_module.c @@ -7,6 +7,7 @@ * Copyright (c) 2021 Triad National Security, LLC. All rights * reserved. * Copyright (c) 2022 IBM Corporation. All rights reserved + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. 
* $COPYRIGHT$ * * Additional copyrights may follow @@ -29,15 +30,16 @@ /* * Local functions */ -static int han_module_enable(mca_coll_base_module_t * module, - struct ompi_communicator_t *comm); +static int mca_coll_han_module_enable(mca_coll_base_module_t * module, + struct ompi_communicator_t *comm); static int mca_coll_han_module_disable(mca_coll_base_module_t * module, struct ompi_communicator_t *comm); -#define CLEAN_PREV_COLL(HANDLE, NAME) \ - do { \ - (HANDLE)->fallback.NAME.module_fn.NAME = NULL; \ - (HANDLE)->fallback.NAME.module = NULL; \ +#define CLEAN_PREV_COLL(__module, __api) \ + do \ + { \ + (__module)->previous_##__api = NULL; \ + (__module)->previous_##__api##_module = NULL; \ } while (0) /* @@ -71,7 +73,6 @@ static void mca_coll_han_module_construct(mca_coll_han_module_t * module) module->enabled = true; module->recursive_free_depth = 0; - module->super.coll_module_disable = mca_coll_han_module_disable; module->cached_low_comms = NULL; module->cached_up_comms = NULL; module->cached_vranks = NULL; @@ -90,14 +91,6 @@ static void mca_coll_han_module_construct(mca_coll_han_module_t * module) han_module_clear(module); } - -#define OBJ_RELEASE_IF_NOT_NULL(obj) \ - do { \ - if (NULL != (obj)) { \ - OBJ_RELEASE(obj); \ - } \ - } while (0) - /* * Module destructor */ @@ -146,15 +139,6 @@ mca_coll_han_module_destruct(mca_coll_han_module_t * module) } } - OBJ_RELEASE_IF_NOT_NULL(module->previous_allgather_module); - OBJ_RELEASE_IF_NOT_NULL(module->previous_allreduce_module); - OBJ_RELEASE_IF_NOT_NULL(module->previous_bcast_module); - OBJ_RELEASE_IF_NOT_NULL(module->previous_gather_module); - OBJ_RELEASE_IF_NOT_NULL(module->previous_gatherv_module); - OBJ_RELEASE_IF_NOT_NULL(module->previous_reduce_module); - OBJ_RELEASE_IF_NOT_NULL(module->previous_scatter_module); - OBJ_RELEASE_IF_NOT_NULL(module->previous_scatterv_module); - han_module_clear(module); } @@ -214,13 +198,7 @@ mca_coll_han_comm_query(struct ompi_communicator_t * comm, int *priority) return 
NULL; } - han_module = OBJ_NEW(mca_coll_han_module_t); - if (NULL == han_module) { - return NULL; - } - - /* All is good -- return a module */ - han_module->topologic_level = GLOBAL_COMMUNICATOR; + int topologic_level = GLOBAL_COMMUNICATOR; if (NULL != comm->super.s_info) { /* Get the info value disaqualifying coll components */ @@ -229,27 +207,31 @@ mca_coll_han_comm_query(struct ompi_communicator_t * comm, int *priority) &info_str, &flag); if (flag) { - if (0 == strcmp(info_str->string, "INTER_NODE")) { - han_module->topologic_level = INTER_NODE; - } else { - han_module->topologic_level = INTRA_NODE; - } + topologic_level = strcmp(info_str->string, "INTER_NODE") ? INTRA_NODE : INTER_NODE; OBJ_RELEASE(info_str); } } if( !ompi_group_have_remote_peers(comm->c_local_group) - && INTRA_NODE != han_module->topologic_level ) { + && INTRA_NODE != topologic_level ) { /* The group only contains local processes, and this is not a * intra-node subcomm we created. Disable HAN for now */ opal_output_verbose(10, ompi_coll_base_framework.framework_output, "coll:han:comm_query (%s/%s): comm has only local processes; disqualifying myself", ompi_comm_print_cid(comm), comm->c_name); - OBJ_RELEASE(han_module); return NULL; } - han_module->super.coll_module_enable = han_module_enable; + /* All is good -- return a module */ + han_module = OBJ_NEW(mca_coll_han_module_t); + if (NULL == han_module) { + return NULL; + } + han_module->topologic_level = topologic_level; + + han_module->super.coll_module_enable = mca_coll_han_module_enable; + han_module->super.coll_module_disable = mca_coll_han_module_disable; + han_module->super.coll_alltoall = NULL; han_module->super.coll_alltoallv = NULL; han_module->super.coll_alltoallw = NULL; @@ -265,7 +247,6 @@ mca_coll_han_comm_query(struct ompi_communicator_t * comm, int *priority) han_module->super.coll_bcast = mca_coll_han_bcast_intra_dynamic; han_module->super.coll_allreduce = mca_coll_han_allreduce_intra_dynamic; han_module->super.coll_allgather = 
mca_coll_han_allgather_intra_dynamic; - if (GLOBAL_COMMUNICATOR == han_module->topologic_level) { /* We are on the global communicator, return topological algorithms */ han_module->super.coll_allgatherv = NULL; @@ -287,57 +268,48 @@ mca_coll_han_comm_query(struct ompi_communicator_t * comm, int *priority) * . ompi_communicator_t *comm * . mca_coll_han_module_t *han_module */ -#define HAN_SAVE_PREV_COLL_API(__api) \ +#define HAN_INSTALL_COLL_API(__comm, __module, __api) \ do { \ - if (!comm->c_coll->coll_ ## __api || !comm->c_coll->coll_ ## __api ## _module) { \ - opal_output_verbose(1, ompi_coll_base_framework.framework_output, \ - "(%s/%s): no underlying " # __api"; disqualifying myself", \ - ompi_comm_print_cid(comm), comm->c_name); \ - goto handle_error; \ - } \ - han_module->previous_ ## __api = comm->c_coll->coll_ ## __api; \ - han_module->previous_ ## __api ## _module = comm->c_coll->coll_ ## __api ## _module; \ - OBJ_RETAIN(han_module->previous_ ## __api ## _module); \ - } while(0) + if( NULL != __module->super.coll_ ## __api ) { \ + if (!__comm->c_coll->coll_##__api || !__comm->c_coll->coll_##__api##_module) { \ + opal_output_verbose(10, ompi_coll_base_framework.framework_output, \ + "(%d/%s): no underlying " #__api " ; disqualifying myself", \ + __comm->c_index, __comm->c_name); \ + } else { \ + MCA_COLL_SAVE_API(__comm, __api, __module->previous_##__api, \ + __module->previous_##__api##_module, "han"); \ + MCA_COLL_INSTALL_API(__comm, __api, __module->super.coll_##__api, &__module->super, "han"); \ + } \ + } \ + } while (0) + +/* The HAN_UNINSTALL_COLL_API is in coll_han.h header as it is needed in several places. 
*/ /* * Init module on the communicator */ static int -han_module_enable(mca_coll_base_module_t * module, - struct ompi_communicator_t *comm) +mca_coll_han_module_enable(mca_coll_base_module_t * module, + struct ompi_communicator_t *comm) { mca_coll_han_module_t * han_module = (mca_coll_han_module_t*) module; - HAN_SAVE_PREV_COLL_API(allgather); - HAN_SAVE_PREV_COLL_API(allgatherv); - HAN_SAVE_PREV_COLL_API(allreduce); - HAN_SAVE_PREV_COLL_API(barrier); - HAN_SAVE_PREV_COLL_API(bcast); - HAN_SAVE_PREV_COLL_API(gather); - HAN_SAVE_PREV_COLL_API(gatherv); - HAN_SAVE_PREV_COLL_API(reduce); - HAN_SAVE_PREV_COLL_API(scatter); - HAN_SAVE_PREV_COLL_API(scatterv); + HAN_INSTALL_COLL_API(comm, han_module, allgather); + HAN_INSTALL_COLL_API(comm, han_module, allgatherv); + HAN_INSTALL_COLL_API(comm, han_module, allreduce); + HAN_INSTALL_COLL_API(comm, han_module, barrier); + HAN_INSTALL_COLL_API(comm, han_module, bcast); + HAN_INSTALL_COLL_API(comm, han_module, gather); + HAN_INSTALL_COLL_API(comm, han_module, gatherv); + HAN_INSTALL_COLL_API(comm, han_module, reduce); + HAN_INSTALL_COLL_API(comm, han_module, scatter); + HAN_INSTALL_COLL_API(comm, han_module, scatterv); /* set reproducible algos */ mca_coll_han_reduce_reproducible_decision(comm, module); mca_coll_han_allreduce_reproducible_decision(comm, module); return OMPI_SUCCESS; - -handle_error: - OBJ_RELEASE_IF_NOT_NULL(han_module->previous_allgather_module); - OBJ_RELEASE_IF_NOT_NULL(han_module->previous_allgatherv_module); - OBJ_RELEASE_IF_NOT_NULL(han_module->previous_allreduce_module); - OBJ_RELEASE_IF_NOT_NULL(han_module->previous_bcast_module); - OBJ_RELEASE_IF_NOT_NULL(han_module->previous_gather_module); - OBJ_RELEASE_IF_NOT_NULL(han_module->previous_gatherv_module); - OBJ_RELEASE_IF_NOT_NULL(han_module->previous_reduce_module); - OBJ_RELEASE_IF_NOT_NULL(han_module->previous_scatter_module); - OBJ_RELEASE_IF_NOT_NULL(han_module->previous_scatterv_module); - - return OMPI_ERROR; } /* @@ -349,16 +321,16 @@ 
mca_coll_han_module_disable(mca_coll_base_module_t * module, { mca_coll_han_module_t * han_module = (mca_coll_han_module_t *) module; - OBJ_RELEASE_IF_NOT_NULL(han_module->previous_allgather_module); - OBJ_RELEASE_IF_NOT_NULL(han_module->previous_allgatherv_module); - OBJ_RELEASE_IF_NOT_NULL(han_module->previous_allreduce_module); - OBJ_RELEASE_IF_NOT_NULL(han_module->previous_barrier_module); - OBJ_RELEASE_IF_NOT_NULL(han_module->previous_bcast_module); - OBJ_RELEASE_IF_NOT_NULL(han_module->previous_gather_module); - OBJ_RELEASE_IF_NOT_NULL(han_module->previous_gatherv_module); - OBJ_RELEASE_IF_NOT_NULL(han_module->previous_reduce_module); - OBJ_RELEASE_IF_NOT_NULL(han_module->previous_scatter_module); - OBJ_RELEASE_IF_NOT_NULL(han_module->previous_scatterv_module); + HAN_UNINSTALL_COLL_API(comm, han_module, allgather); + HAN_UNINSTALL_COLL_API(comm, han_module, allgatherv); + HAN_UNINSTALL_COLL_API(comm, han_module, allreduce); + HAN_UNINSTALL_COLL_API(comm, han_module, barrier); + HAN_UNINSTALL_COLL_API(comm, han_module, bcast); + HAN_UNINSTALL_COLL_API(comm, han_module, gather); + HAN_UNINSTALL_COLL_API(comm, han_module, gatherv); + HAN_UNINSTALL_COLL_API(comm, han_module, reduce); + HAN_UNINSTALL_COLL_API(comm, han_module, scatter); + HAN_UNINSTALL_COLL_API(comm, han_module, scatterv); han_module_clear(han_module); diff --git a/ompi/mca/coll/han/coll_han_reduce.c b/ompi/mca/coll/han/coll_han_reduce.c index 53d51dca3e3..4cb5c012d47 100644 --- a/ompi/mca/coll/han/coll_han_reduce.c +++ b/ompi/mca/coll/han/coll_han_reduce.c @@ -6,6 +6,7 @@ * Copyright (c) 2022 IBM Corporation. All rights reserved * Copyright (c) 2024 Computer Architecture and VLSI Systems (CARV) * Laboratory, ICS Forth. All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. 
* $COPYRIGHT$ * * Additional copyrights may follow @@ -92,7 +93,7 @@ mca_coll_han_reduce_intra(const void *sbuf, OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output, "han cannot handle reduce with this communicator. Drop HAN support in this communicator and fall back on another component\n")); /* HAN cannot work with this communicator so fallback on all modules */ - HAN_LOAD_FALLBACK_COLLECTIVES(han_module, comm); + HAN_LOAD_FALLBACK_COLLECTIVES(comm, han_module); return han_module->previous_reduce(sbuf, rbuf, count, dtype, op, root, comm, han_module->previous_reduce_module); } @@ -106,7 +107,7 @@ mca_coll_han_reduce_intra(const void *sbuf, /* Put back the fallback collective support and call it once. All * future calls will then be automatically redirected. */ - HAN_LOAD_FALLBACK_COLLECTIVE(han_module, comm, reduce); + HAN_UNINSTALL_COLL_API(comm, han_module, reduce); return han_module->previous_reduce(sbuf, rbuf, count, dtype, op, root, comm, han_module->previous_reduce_module); } @@ -311,7 +312,7 @@ mca_coll_han_reduce_intra_simple(const void *sbuf, OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output, "han cannot handle reduce with this communicator. Drop HAN support in this communicator and fall back on another component\n")); /* HAN cannot work with this communicator so fallback on all collectives */ - HAN_LOAD_FALLBACK_COLLECTIVES(han_module, comm); + HAN_LOAD_FALLBACK_COLLECTIVES(comm, han_module); return han_module->previous_reduce(sbuf, rbuf, count, dtype, op, root, comm, han_module->previous_reduce_module); } @@ -325,7 +326,7 @@ mca_coll_han_reduce_intra_simple(const void *sbuf, /* Put back the fallback collective support and call it once. All * future calls will then be automatically redirected. 
*/ - HAN_LOAD_FALLBACK_COLLECTIVE(han_module, comm, reduce); + HAN_UNINSTALL_COLL_API(comm, han_module, reduce); return han_module->previous_reduce(sbuf, rbuf, count, dtype, op, root, comm, han_module->previous_reduce_module); } diff --git a/ompi/mca/coll/han/coll_han_scatter.c b/ompi/mca/coll/han/coll_han_scatter.c index 87201597bb4..16193d54904 100644 --- a/ompi/mca/coll/han/coll_han_scatter.c +++ b/ompi/mca/coll/han/coll_han_scatter.c @@ -3,6 +3,7 @@ * of Tennessee Research Foundation. All rights * reserved. * Copyright (c) 2022 IBM Corporation. All rights reserved + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -86,7 +87,7 @@ mca_coll_han_scatter_intra(const void *sbuf, int scount, OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output, "han cannot handle scatter with this communicator. Fall back on another component\n")); /* HAN cannot work with this communicator so fallback on all collectives */ - HAN_LOAD_FALLBACK_COLLECTIVES(han_module, comm); + HAN_LOAD_FALLBACK_COLLECTIVES(comm, han_module); return han_module->previous_scatter(sbuf, scount, sdtype, rbuf, rcount, rdtype, root, comm, han_module->previous_scatter_module); } @@ -100,7 +101,7 @@ mca_coll_han_scatter_intra(const void *sbuf, int scount, /* Put back the fallback collective support and call it once. All * future calls will then be automatically redirected. */ - HAN_LOAD_FALLBACK_COLLECTIVE(han_module, comm, scatter); + HAN_UNINSTALL_COLL_API(comm, han_module, scatter); return han_module->previous_scatter(sbuf, scount, sdtype, rbuf, rcount, rdtype, root, comm, han_module->previous_scatter_module); } @@ -284,7 +285,7 @@ mca_coll_han_scatter_intra_simple(const void *sbuf, int scount, "han cannot handle allgather within this communicator." 
" Fall back on another component\n")); /* HAN cannot work with this communicator so fallback on all collectives */ - HAN_LOAD_FALLBACK_COLLECTIVES(han_module, comm); + HAN_LOAD_FALLBACK_COLLECTIVES(comm, han_module); return han_module->previous_scatter(sbuf, scount, sdtype, rbuf, rcount, rdtype, root, comm, han_module->previous_scatter_module); } @@ -294,7 +295,7 @@ mca_coll_han_scatter_intra_simple(const void *sbuf, int scount, if (han_module->are_ppn_imbalanced){ OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output, "han cannot handle scatter with this communicator. It needs to fall back on another component\n")); - HAN_LOAD_FALLBACK_COLLECTIVES(han_module, comm); + HAN_UNINSTALL_COLL_API(comm, han_module, scatter); return han_module->previous_scatter(sbuf, scount, sdtype, rbuf, rcount, rdtype, root, comm, han_module->previous_scatter_module); } diff --git a/ompi/mca/coll/han/coll_han_scatterv.c b/ompi/mca/coll/han/coll_han_scatterv.c index 7cd7f276d3f..16b16754313 100644 --- a/ompi/mca/coll/han/coll_han_scatterv.c +++ b/ompi/mca/coll/han/coll_han_scatterv.c @@ -7,6 +7,7 @@ * Copyright (c) 2022 IBM Corporation. All rights reserved * Copyright (c) Amazon.com, Inc. or its affiliates. * All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -81,7 +82,7 @@ int mca_coll_han_scatterv_intra(const void *sbuf, const int *scounts, const int 30, mca_coll_han_component.han_output, "han cannot handle scatterv with this communicator. 
Fall back on another component\n")); /* HAN cannot work with this communicator so fallback on all collectives */ - HAN_LOAD_FALLBACK_COLLECTIVES(han_module, comm); + HAN_LOAD_FALLBACK_COLLECTIVES(comm, han_module); return han_module->previous_scatterv(sbuf, scounts, displs, sdtype, rbuf, rcount, rdtype, root, comm, han_module->previous_scatterv_module); } @@ -96,7 +97,7 @@ int mca_coll_han_scatterv_intra(const void *sbuf, const int *scounts, const int /* Put back the fallback collective support and call it once. All * future calls will then be automatically redirected. */ - HAN_LOAD_FALLBACK_COLLECTIVE(han_module, comm, scatterv); + HAN_UNINSTALL_COLL_API(comm, han_module, scatterv); return han_module->previous_scatterv(sbuf, scounts, displs, sdtype, rbuf, rcount, rdtype, root, comm, han_module->previous_scatterv_module); } @@ -104,7 +105,7 @@ int mca_coll_han_scatterv_intra(const void *sbuf, const int *scounts, const int OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output, "han cannot handle scatterv with this communicator (heterogeneous). Fall " "back on another component\n")); - HAN_LOAD_FALLBACK_COLLECTIVE(han_module, comm, scatterv); + HAN_UNINSTALL_COLL_API(comm, han_module, scatterv); return han_module->previous_scatterv(sbuf, scounts, displs, sdtype, rbuf, rcount, rdtype, root, comm, han_module->previous_scatterv_module); } diff --git a/ompi/mca/coll/han/coll_han_subcomms.c b/ompi/mca/coll/han/coll_han_subcomms.c index 47ef348975e..d1330188f41 100644 --- a/ompi/mca/coll/han/coll_han_subcomms.c +++ b/ompi/mca/coll/han/coll_han_subcomms.c @@ -4,6 +4,7 @@ * reserved. * Copyright (c) 2020 Bull S.A.S. All rights reserved. * + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. 
* $COPYRIGHT$ * * Additional copyrights may follow @@ -26,19 +27,21 @@ #include "coll_han.h" #include "coll_han_dynamic.h" -#define HAN_SUBCOM_SAVE_COLLECTIVE(FALLBACKS, COMM, HANM, COLL) \ - do { \ - (FALLBACKS).COLL.module_fn.COLL = (COMM)->c_coll->coll_ ## COLL; \ - (FALLBACKS).COLL.module = (COMM)->c_coll->coll_ ## COLL ## _module; \ - (COMM)->c_coll->coll_ ## COLL = (HANM)->fallback.COLL.module_fn.COLL; \ - (COMM)->c_coll->coll_ ## COLL ## _module = (HANM)->fallback.COLL.module; \ - } while(0) - -#define HAN_SUBCOM_LOAD_COLLECTIVE(FALLBACKS, COMM, HANM, COLL) \ - do { \ - (COMM)->c_coll->coll_ ## COLL = (FALLBACKS).COLL.module_fn.COLL; \ - (COMM)->c_coll->coll_ ## COLL ## _module = (FALLBACKS).COLL.module; \ - } while(0) +#define HAN_SUBCOM_SAVE_COLLECTIVE(FALLBACKS, COMM, HANM, COLL) \ + do \ + { \ + (FALLBACKS).COLL.COLL = (COMM)->c_coll->coll_##COLL; \ + (FALLBACKS).COLL.module = (COMM)->c_coll->coll_##COLL##_module; \ + (COMM)->c_coll->coll_##COLL = (HANM)->fallback.COLL.COLL; \ + (COMM)->c_coll->coll_##COLL##_module = (HANM)->fallback.COLL.module; \ + } while (0) + +#define HAN_SUBCOM_RESTORE_COLLECTIVE(FALLBACKS, COMM, HANM, COLL) \ + do \ + { \ + (COMM)->c_coll->coll_##COLL = (FALLBACKS).COLL.COLL; \ + (COMM)->c_coll->coll_##COLL##_module = (FALLBACKS).COLL.module; \ + } while (0) /* * Routine that creates the local hierarchical sub-communicators @@ -64,8 +67,8 @@ int mca_coll_han_comm_create_new(struct ompi_communicator_t *comm, /* * We cannot use han allreduce and allgather without sub-communicators, - * but we are in the creation of the data structures for the HAN, and - * temporarily need to save back the old collective. + * but we are in the creation of the data structures for HAN, and + * temporarily need to use the old collective. 
* * Allgather is used to compute vranks * Allreduce is used by ompi_comm_split_type in create_intranode_comm_new @@ -90,7 +93,8 @@ int mca_coll_han_comm_create_new(struct ompi_communicator_t *comm, * outside the MPI support (with PRRTE the info will be eventually available, * but we don't want to delay anything until then). We can achieve the same * goal by using a reduction over the maximum number of peers per node among - * all participants. + * all participants, but we need to call the fallback allreduce (or we will + * call HAN's allreduce again). */ int local_procs = ompi_group_count_local_peers(comm->c_local_group); rc = comm->c_coll->coll_allreduce(MPI_IN_PLACE, &local_procs, 1, MPI_INT, @@ -100,18 +104,18 @@ int mca_coll_han_comm_create_new(struct ompi_communicator_t *comm, goto return_with_error; } if( local_procs == 1 ) { - /* restore saved collectives */ - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, allgatherv); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, allgather); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, allreduce); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, bcast); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, reduce); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, gather); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, gatherv); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, scatter); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, scatterv); han_module->enabled = false; /* entire module set to pass-through from now on */ - return OMPI_ERR_NOT_SUPPORTED; + /* restore saved collectives */ + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, allgatherv); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, allgather); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, allreduce); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, bcast); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, 
reduce); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, gather); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, gatherv); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, scatter); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, scatterv); + return OMPI_ERR_NOT_SUPPORTED; } OBJ_CONSTRUCT(&comm_info, opal_info_t); @@ -178,21 +182,22 @@ int mca_coll_han_comm_create_new(struct ompi_communicator_t *comm, */ han_module->cached_vranks = vranks; - /* Reset the saved collectives to point back to HAN */ - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, allgatherv); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, allgather); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, allreduce); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, bcast); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, reduce); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, gather); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, gatherv); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, scatter); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, scatterv); + /* Restore the saved collectives */ + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, allgatherv); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, allgather); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, allreduce); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, bcast); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, reduce); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, gather); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, gatherv); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, scatter); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, scatterv); OBJ_DESTRUCT(&comm_info); return OMPI_SUCCESS; return_with_error: + han_module->enabled = false; /* entire module set to pass-through from now on */ 
if( NULL != *low_comm ) { ompi_comm_free(low_comm); *low_comm = NULL; /* don't leave the MPI_COMM_NULL set by ompi_comm_free */ @@ -229,7 +234,7 @@ int mca_coll_han_comm_create(struct ompi_communicator_t *comm, /* * We cannot use han allreduce and allgather without sub-communicators, * but we are in the creation of the data structures for the HAN, and - * temporarily need to save back the old collective. + * temporarily need to use the old collective. * * Allgather is used to compute vranks * Allreduce is used by ompi_comm_split_type in create_intranode_comm_new @@ -262,15 +267,15 @@ int mca_coll_han_comm_create(struct ompi_communicator_t *comm, comm->c_coll->coll_allreduce_module); if( local_procs == 1 ) { /* restore saved collectives */ - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, allgatherv); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, allgather); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, allreduce); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, bcast); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, reduce); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, gather); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, gatherv); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, scatter); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, scatterv); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, allgatherv); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, allgather); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, allreduce); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, bcast); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, reduce); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, gather); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, gatherv); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, scatter); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, 
scatterv); han_module->enabled = false; /* entire module set to pass-through from now on */ return OMPI_ERR_NOT_SUPPORTED; } @@ -352,15 +357,15 @@ int mca_coll_han_comm_create(struct ompi_communicator_t *comm, han_module->cached_vranks = vranks; /* Reset the saved collectives to point back to HAN */ - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, allgatherv); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, allgather); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, allreduce); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, bcast); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, reduce); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, gather); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, gatherv); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, scatter); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, scatterv); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, allgatherv); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, allgather); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, allreduce); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, bcast); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, reduce); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, gather); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, gatherv); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, scatter); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, scatterv); OBJ_DESTRUCT(&comm_info); return OMPI_SUCCESS; diff --git a/ompi/mca/coll/hcoll/coll_hcoll_module.c b/ompi/mca/coll/hcoll/coll_hcoll_module.c index 0b8c2db4d48..5ca588a8154 100644 --- a/ompi/mca/coll/hcoll/coll_hcoll_module.c +++ b/ompi/mca/coll/hcoll/coll_hcoll_module.c @@ -7,6 +7,7 @@ * Copyright (c) 2018 Cisco Systems, Inc. All rights reserved * Copyright (c) 2022 Amazon.com, Inc. or its affiliates. * All Rights reserved. 
+ * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -86,8 +87,6 @@ static void mca_coll_hcoll_module_clear(mca_coll_hcoll_module_t *hcoll_module) hcoll_module->previous_igatherv_module = NULL; hcoll_module->previous_ialltoall_module = NULL; hcoll_module->previous_ialltoallv_module = NULL; - - } static void mca_coll_hcoll_module_construct(mca_coll_hcoll_module_t *hcoll_module) @@ -101,8 +100,6 @@ void mca_coll_hcoll_mem_release_cb(void *buf, size_t length, hcoll_mem_unmap(buf, length, cbdata, from_alloc); } -#define OBJ_RELEASE_IF_NOT_NULL( obj ) if( NULL != (obj) ) OBJ_RELEASE( obj ); - static void mca_coll_hcoll_module_destruct(mca_coll_hcoll_module_t *hcoll_module) { int context_destroyed; @@ -119,37 +116,7 @@ static void mca_coll_hcoll_module_destruct(mca_coll_hcoll_module_t *hcoll_module destroy hcoll context*/ if (hcoll_module->hcoll_context != NULL){ - OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_barrier_module); - OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_bcast_module); - OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_allreduce_module); - OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_reduce_scatter_block_module); - OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_reduce_scatter_module); - OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_allgather_module); - OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_allgatherv_module); - OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_gatherv_module); - OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_scatterv_module); - OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_alltoall_module); - OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_alltoallv_module); - OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_reduce_module); - - OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_ibarrier_module); - OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_ibcast_module); - OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_iallreduce_module); - 
OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_iallgather_module); - OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_iallgatherv_module); - OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_igatherv_module); - OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_ialltoall_module); - OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_ialltoallv_module); - OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_ireduce_module); - - /* - OBJ_RELEASE(hcoll_module->previous_allgatherv_module); - OBJ_RELEASE(hcoll_module->previous_gather_module); - OBJ_RELEASE(hcoll_module->previous_gatherv_module); - OBJ_RELEASE(hcoll_module->previous_alltoallw_module); - OBJ_RELEASE(hcoll_module->previous_reduce_scatter_module); - OBJ_RELEASE(hcoll_module->previous_reduce_module); - */ + #if !defined(HAVE_HCOLL_CONTEXT_FREE) context_destroyed = 0; hcoll_destroy_context(hcoll_module->hcoll_context, @@ -160,52 +127,105 @@ static void mca_coll_hcoll_module_destruct(mca_coll_hcoll_module_t *hcoll_module mca_coll_hcoll_module_clear(hcoll_module); } -#define HCOL_SAVE_PREV_COLL_API(__api) do {\ - hcoll_module->previous_ ## __api = comm->c_coll->coll_ ## __api;\ - hcoll_module->previous_ ## __api ## _module = comm->c_coll->coll_ ## __api ## _module;\ - if (!comm->c_coll->coll_ ## __api || !comm->c_coll->coll_ ## __api ## _module) {\ - return OMPI_ERROR;\ - }\ - OBJ_RETAIN(hcoll_module->previous_ ## __api ## _module);\ -} while(0) - +#define HCOL_INSTALL_COLL_API(__comm, __module, __api) \ + do \ + { \ + if (NULL != __module->super.coll_##__api) \ + { \ + if (comm->c_coll->coll_##__api && !comm->c_coll->coll_##__api##_module) \ + { \ + /* save the current selected collective */ \ + MCA_COLL_SAVE_API(__comm, __api, hcoll_module->previous_##__api, hcoll_module->previous_##__api##_module, "hcoll"); \ + /* install our own */ \ + MCA_COLL_INSTALL_API(__comm, __api, __module->super.coll_##__api, &__module->super, "hcoll"); \ + } \ + } \ + } while (0) + +#define HCOL_UNINSTALL_COLL_API(__comm, __module, __api) \ + do \ + { 
\ + if (&__module->super == comm->c_coll->coll_##__api##_module) \ + { \ + MCA_COLL_INSTALL_API(__comm, __api, __module->previous_##__api, __module->previous_##__api##_module, "hcoll"); \ + hcoll_module->previous_##__api = NULL; \ + hcoll_module->previous_##__api##_module = NULL; \ + } \ + } while (0) static int mca_coll_hcoll_save_coll_handlers(mca_coll_hcoll_module_t *hcoll_module) { ompi_communicator_t *comm; comm = hcoll_module->comm; - HCOL_SAVE_PREV_COLL_API(barrier); - HCOL_SAVE_PREV_COLL_API(bcast); - HCOL_SAVE_PREV_COLL_API(allreduce); - HCOL_SAVE_PREV_COLL_API(reduce_scatter_block); - HCOL_SAVE_PREV_COLL_API(reduce_scatter); - HCOL_SAVE_PREV_COLL_API(reduce); - HCOL_SAVE_PREV_COLL_API(allgather); - HCOL_SAVE_PREV_COLL_API(allgatherv); - HCOL_SAVE_PREV_COLL_API(gatherv); - HCOL_SAVE_PREV_COLL_API(scatterv); - HCOL_SAVE_PREV_COLL_API(alltoall); - HCOL_SAVE_PREV_COLL_API(alltoallv); - - HCOL_SAVE_PREV_COLL_API(ibarrier); - HCOL_SAVE_PREV_COLL_API(ibcast); - HCOL_SAVE_PREV_COLL_API(iallreduce); - HCOL_SAVE_PREV_COLL_API(ireduce); - HCOL_SAVE_PREV_COLL_API(iallgather); - HCOL_SAVE_PREV_COLL_API(iallgatherv); - HCOL_SAVE_PREV_COLL_API(igatherv); - HCOL_SAVE_PREV_COLL_API(ialltoall); - HCOL_SAVE_PREV_COLL_API(ialltoallv); + hcoll_module->super.coll_barrier = hcoll_collectives.coll_barrier ? mca_coll_hcoll_barrier : NULL; + hcoll_module->super.coll_bcast = hcoll_collectives.coll_bcast ? mca_coll_hcoll_bcast : NULL; + hcoll_module->super.coll_allgather = hcoll_collectives.coll_allgather ? mca_coll_hcoll_allgather : NULL; + hcoll_module->super.coll_allgatherv = hcoll_collectives.coll_allgatherv ? mca_coll_hcoll_allgatherv : NULL; + hcoll_module->super.coll_allreduce = hcoll_collectives.coll_allreduce ? mca_coll_hcoll_allreduce : NULL; + hcoll_module->super.coll_alltoall = hcoll_collectives.coll_alltoall ? mca_coll_hcoll_alltoall : NULL; + hcoll_module->super.coll_alltoallv = hcoll_collectives.coll_alltoallv ? 
mca_coll_hcoll_alltoallv : NULL; + hcoll_module->super.coll_gatherv = hcoll_collectives.coll_gatherv ? mca_coll_hcoll_gatherv : NULL; + hcoll_module->super.coll_scatterv = hcoll_collectives.coll_scatterv ? mca_coll_hcoll_scatterv : NULL; + hcoll_module->super.coll_reduce = hcoll_collectives.coll_reduce ? mca_coll_hcoll_reduce : NULL; + hcoll_module->super.coll_ibarrier = hcoll_collectives.coll_ibarrier ? mca_coll_hcoll_ibarrier : NULL; + hcoll_module->super.coll_ibcast = hcoll_collectives.coll_ibcast ? mca_coll_hcoll_ibcast : NULL; + hcoll_module->super.coll_iallgather = hcoll_collectives.coll_iallgather ? mca_coll_hcoll_iallgather : NULL; +#if HCOLL_API >= HCOLL_VERSION(3, 5) + hcoll_module->super.coll_iallgatherv = hcoll_collectives.coll_iallgatherv ? mca_coll_hcoll_iallgatherv : NULL; +#else + hcoll_module->super.coll_iallgatherv = NULL; +#endif + hcoll_module->super.coll_iallreduce = hcoll_collectives.coll_iallreduce ? mca_coll_hcoll_iallreduce : NULL; +#if HCOLL_API >= HCOLL_VERSION(3, 5) + hcoll_module->super.coll_ireduce = hcoll_collectives.coll_ireduce ? mca_coll_hcoll_ireduce : NULL; +#else + hcoll_module->super.coll_ireduce = NULL; +#endif + hcoll_module->super.coll_gather = /*hcoll_collectives.coll_gather ? mca_coll_hcoll_gather :*/ NULL; + hcoll_module->super.coll_igatherv = hcoll_collectives.coll_igatherv ? mca_coll_hcoll_igatherv : NULL; + hcoll_module->super.coll_ialltoall = /*hcoll_collectives.coll_ialltoall ? mca_coll_hcoll_ialltoall : */ NULL; +#if HCOLL_API >= HCOLL_VERSION(3, 7) + hcoll_module->super.coll_ialltoallv = hcoll_collectives.coll_ialltoallv ? mca_coll_hcoll_ialltoallv : NULL; +#else + hcoll_module->super.coll_ialltoallv = NULL; +#endif +#if HCOLL_API > HCOLL_VERSION(4, 5) + hcoll_module->super.coll_reduce_scatter_block = hcoll_collectives.coll_reduce_scatter_block ? mca_coll_hcoll_reduce_scatter_block : NULL; + hcoll_module->super.coll_reduce_scatter = hcoll_collectives.coll_reduce_scatter ? 
mca_coll_hcoll_reduce_scatter : NULL; +#endif + + HCOL_INSTALL_COLL_API(comm, hcoll_module, barrier); + HCOL_INSTALL_COLL_API(comm, hcoll_module, bcast); + HCOL_INSTALL_COLL_API(comm, hcoll_module, allreduce); + HCOL_INSTALL_COLL_API(comm, hcoll_module, reduce_scatter_block); + HCOL_INSTALL_COLL_API(comm, hcoll_module, reduce_scatter); + HCOL_INSTALL_COLL_API(comm, hcoll_module, reduce); + HCOL_INSTALL_COLL_API(comm, hcoll_module, allgather); + HCOL_INSTALL_COLL_API(comm, hcoll_module, allgatherv); + HCOL_INSTALL_COLL_API(comm, hcoll_module, gatherv); + HCOL_INSTALL_COLL_API(comm, hcoll_module, scatterv); + HCOL_INSTALL_COLL_API(comm, hcoll_module, alltoall); + HCOL_INSTALL_COLL_API(comm, hcoll_module, alltoallv); + + HCOL_INSTALL_COLL_API(comm, hcoll_module, ibarrier); + HCOL_INSTALL_COLL_API(comm, hcoll_module, ibcast); + HCOL_INSTALL_COLL_API(comm, hcoll_module, iallreduce); + HCOL_INSTALL_COLL_API(comm, hcoll_module, ireduce); + HCOL_INSTALL_COLL_API(comm, hcoll_module, iallgather); + HCOL_INSTALL_COLL_API(comm, hcoll_module, iallgatherv); + HCOL_INSTALL_COLL_API(comm, hcoll_module, igatherv); + HCOL_INSTALL_COLL_API(comm, hcoll_module, ialltoall); + HCOL_INSTALL_COLL_API(comm, hcoll_module, ialltoallv); /* These collectives are not yet part of hcoll, so don't retain them on hcoll module - HCOL_SAVE_PREV_COLL_API(reduce_scatter); - HCOL_SAVE_PREV_COLL_API(gather); - HCOL_SAVE_PREV_COLL_API(reduce); - HCOL_SAVE_PREV_COLL_API(allgatherv); - HCOL_SAVE_PREV_COLL_API(alltoallw); + HCOL_INSTALL_COLL_API(comm, hcoll_module, reduce_scatter); + HCOL_INSTALL_COLL_API(comm, hcoll_module, gather); + HCOL_INSTALL_COLL_API(comm, hcoll_module, reduce); + HCOL_INSTALL_COLL_API(comm, hcoll_module, allgatherv); + HCOL_INSTALL_COLL_API(comm, hcoll_module, alltoallw); */ return OMPI_SUCCESS; } @@ -251,6 +271,45 @@ static int mca_coll_hcoll_module_enable(mca_coll_base_module_t *module, return OMPI_SUCCESS; } +static int mca_coll_hcoll_module_disable(mca_coll_base_module_t *module, 
+ struct ompi_communicator_t *comm) +{ + mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t *)module; + + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, barrier); + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, bcast); + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, allreduce); + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, reduce_scatter_block); + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, reduce_scatter); + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, reduce); + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, allgather); + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, allgatherv); + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, gatherv); + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, scatterv); + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, alltoall); + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, alltoallv); + + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, ibarrier); + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, ibcast); + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, iallreduce); + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, ireduce); + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, iallgather); + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, iallgatherv); + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, igatherv); + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, ialltoall); + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, ialltoallv); + + /* + These collectives are not yet part of hcoll, so + don't retain them on hcoll module + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, reduce_scatter); + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, gather); + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, reduce); + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, allgatherv); + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, alltoallw); + */ + return OMPI_SUCCESS; +} OBJ_CLASS_INSTANCE(mca_coll_hcoll_dtype_t, opal_free_list_item_t, @@ -395,44 +454,8 @@ mca_coll_hcoll_comm_query(struct ompi_communicator_t *comm, int *priority) } hcoll_module->super.coll_module_enable = mca_coll_hcoll_module_enable; - 
hcoll_module->super.coll_barrier = hcoll_collectives.coll_barrier ? mca_coll_hcoll_barrier : NULL; - hcoll_module->super.coll_bcast = hcoll_collectives.coll_bcast ? mca_coll_hcoll_bcast : NULL; - hcoll_module->super.coll_allgather = hcoll_collectives.coll_allgather ? mca_coll_hcoll_allgather : NULL; - hcoll_module->super.coll_allgatherv = hcoll_collectives.coll_allgatherv ? mca_coll_hcoll_allgatherv : NULL; - hcoll_module->super.coll_allreduce = hcoll_collectives.coll_allreduce ? mca_coll_hcoll_allreduce : NULL; - hcoll_module->super.coll_alltoall = hcoll_collectives.coll_alltoall ? mca_coll_hcoll_alltoall : NULL; - hcoll_module->super.coll_alltoallv = hcoll_collectives.coll_alltoallv ? mca_coll_hcoll_alltoallv : NULL; - hcoll_module->super.coll_gatherv = hcoll_collectives.coll_gatherv ? mca_coll_hcoll_gatherv : NULL; - hcoll_module->super.coll_scatterv = hcoll_collectives.coll_scatterv ? mca_coll_hcoll_scatterv : NULL; - hcoll_module->super.coll_reduce = hcoll_collectives.coll_reduce ? mca_coll_hcoll_reduce : NULL; - hcoll_module->super.coll_ibarrier = hcoll_collectives.coll_ibarrier ? mca_coll_hcoll_ibarrier : NULL; - hcoll_module->super.coll_ibcast = hcoll_collectives.coll_ibcast ? mca_coll_hcoll_ibcast : NULL; - hcoll_module->super.coll_iallgather = hcoll_collectives.coll_iallgather ? mca_coll_hcoll_iallgather : NULL; -#if HCOLL_API >= HCOLL_VERSION(3,5) - hcoll_module->super.coll_iallgatherv = hcoll_collectives.coll_iallgatherv ? mca_coll_hcoll_iallgatherv : NULL; -#else - hcoll_module->super.coll_iallgatherv = NULL; -#endif - hcoll_module->super.coll_iallreduce = hcoll_collectives.coll_iallreduce ? mca_coll_hcoll_iallreduce : NULL; -#if HCOLL_API >= HCOLL_VERSION(3,5) - hcoll_module->super.coll_ireduce = hcoll_collectives.coll_ireduce ? mca_coll_hcoll_ireduce : NULL; -#else - hcoll_module->super.coll_ireduce = NULL; -#endif - hcoll_module->super.coll_gather = /*hcoll_collectives.coll_gather ? 
mca_coll_hcoll_gather :*/ NULL; - hcoll_module->super.coll_igatherv = hcoll_collectives.coll_igatherv ? mca_coll_hcoll_igatherv : NULL; - hcoll_module->super.coll_ialltoall = /*hcoll_collectives.coll_ialltoall ? mca_coll_hcoll_ialltoall : */ NULL; -#if HCOLL_API >= HCOLL_VERSION(3,7) - hcoll_module->super.coll_ialltoallv = hcoll_collectives.coll_ialltoallv ? mca_coll_hcoll_ialltoallv : NULL; -#else - hcoll_module->super.coll_ialltoallv = NULL; -#endif -#if HCOLL_API > HCOLL_VERSION(4,5) - hcoll_module->super.coll_reduce_scatter_block = hcoll_collectives.coll_reduce_scatter_block ? - mca_coll_hcoll_reduce_scatter_block : NULL; - hcoll_module->super.coll_reduce_scatter = hcoll_collectives.coll_reduce_scatter ? - mca_coll_hcoll_reduce_scatter : NULL; -#endif + hcoll_module->super.coll_module_disable = mca_coll_hcoll_module_disable; + *priority = cm->hcoll_priority; module = &hcoll_module->super; diff --git a/ompi/mca/coll/inter/Makefile.am b/ompi/mca/coll/inter/Makefile.am index d9c691cf458..ea9188cffd4 100644 --- a/ompi/mca/coll/inter/Makefile.am +++ b/ompi/mca/coll/inter/Makefile.am @@ -11,6 +11,7 @@ # All rights reserved. # Copyright (c) 2010 Cisco Systems, Inc. All rights reserved. # Copyright (c) 2017 IBM Corporation. All rights reserved. +# Copyright (c) 2024 NVIDIA Corporation. All rights reserved. # $COPYRIGHT$ # # Additional copyrights may follow @@ -43,7 +44,7 @@ libmca_coll_inter_la_LDFLAGS = -module -avoid-version sources = \ coll_inter.h \ - coll_inter.c \ + coll_inter_module.c \ coll_inter_allreduce.c \ coll_inter_allgather.c \ coll_inter_allgatherv.c \ diff --git a/ompi/mca/coll/inter/coll_inter.h b/ompi/mca/coll/inter/coll_inter.h index 09bb78c5c79..17dc055451c 100644 --- a/ompi/mca/coll/inter/coll_inter.h +++ b/ompi/mca/coll/inter/coll_inter.h @@ -13,6 +13,7 @@ * Copyright (c) 2008 Cisco Systems, Inc. All rights reserved. * Copyright (c) 2015 Research Organization for Information Science * and Technology (RIST). All rights reserved. 
+ * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -50,9 +51,6 @@ int mca_coll_inter_init_query(bool allow_inter_user_threads, mca_coll_base_module_t * mca_coll_inter_comm_query(struct ompi_communicator_t *comm, int *priority); -int mca_coll_inter_module_enable(mca_coll_base_module_t *module, - struct ompi_communicator_t *comm); - int mca_coll_inter_allgather_inter(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void *rbuf, int rcount, diff --git a/ompi/mca/coll/inter/coll_inter_allgather.c b/ompi/mca/coll/inter/coll_inter_allgather.c index fe867cda06a..59cd19ff75f 100644 --- a/ompi/mca/coll/inter/coll_inter_allgather.c +++ b/ompi/mca/coll/inter/coll_inter_allgather.c @@ -13,6 +13,7 @@ * Copyright (c) 2015-2017 Research Organization for Information Science * and Technology (RIST). All rights reserved. * Copyright (c) 2022 IBM Corporation. All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -61,23 +62,23 @@ mca_coll_inter_allgather_inter(const void *sbuf, int scount, /* Perform the gather locally at the root */ if ( scount > 0 ) { span = opal_datatype_span(&sdtype->super, (int64_t)scount*(int64_t)size, &gap); - ptmp_free = (char*)malloc(span); - if (NULL == ptmp_free) { - return OMPI_ERR_OUT_OF_RESOURCE; - } + ptmp_free = (char*)malloc(span); + if (NULL == ptmp_free) { + return OMPI_ERR_OUT_OF_RESOURCE; + } ptmp = ptmp_free - gap; - err = comm->c_local_comm->c_coll->coll_gather(sbuf, scount, sdtype, + err = comm->c_local_comm->c_coll->coll_gather(sbuf, scount, sdtype, ptmp, scount, sdtype, 0, comm->c_local_comm, comm->c_local_comm->c_coll->coll_gather_module); - if (OMPI_SUCCESS != err) { - goto exit; - } + if (OMPI_SUCCESS != err) { + goto exit; + } } if (rank == root) { - /* Do a send-recv between the two root procs. to avoid deadlock */ + /* Do a send-recv between the two root procs. 
to avoid deadlock */ err = ompi_coll_base_sendrecv_actual(ptmp, scount*(size_t)size, sdtype, 0, MCA_COLL_BASE_TAG_ALLGATHER, rbuf, rcount*(size_t)rsize, rdtype, 0, diff --git a/ompi/mca/coll/inter/coll_inter_component.c b/ompi/mca/coll/inter/coll_inter_component.c index 39ff40c9cfb..76f7340b1be 100644 --- a/ompi/mca/coll/inter/coll_inter_component.c +++ b/ompi/mca/coll/inter/coll_inter_component.c @@ -14,6 +14,7 @@ * Copyright (c) 2008 Cisco Systems, Inc. All rights reserved. * Copyright (c) 2015 Los Alamos National Security, LLC. All rights * reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -113,14 +114,7 @@ mca_coll_inter_module_construct(mca_coll_inter_module_t *module) module->inter_comm = NULL; } -static void -mca_coll_inter_module_destruct(mca_coll_inter_module_t *module) -{ - -} - - OBJ_CLASS_INSTANCE(mca_coll_inter_module_t, mca_coll_base_module_t, mca_coll_inter_module_construct, - mca_coll_inter_module_destruct); + NULL); diff --git a/ompi/mca/coll/inter/coll_inter.c b/ompi/mca/coll/inter/coll_inter_module.c similarity index 63% rename from ompi/mca/coll/inter/coll_inter.c rename to ompi/mca/coll/inter/coll_inter_module.c index d75994e396a..2dd35c10fa2 100644 --- a/ompi/mca/coll/inter/coll_inter.c +++ b/ompi/mca/coll/inter/coll_inter_module.c @@ -13,6 +13,7 @@ * Copyright (c) 2008 Sun Microsystems, Inc. All rights reserved. * Copyright (c) 2013 Cisco Systems, Inc. All rights reserved. * Copyright (c) 2017 IBM Corporation. All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. 
* $COPYRIGHT$ * * Additional copyrights may follow @@ -37,7 +38,6 @@ #if 0 -static void mca_coll_inter_dump_struct ( struct mca_coll_base_comm_t *c); static const mca_coll_base_module_1_0_0_t inter = { @@ -68,6 +68,12 @@ static const mca_coll_base_module_1_0_0_t inter = { }; #endif +static int +mca_coll_inter_module_enable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm); +static int +mca_coll_inter_module_disable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm); /* * Initial query function that is invoked during MPI_INIT, allowing @@ -101,22 +107,23 @@ mca_coll_inter_comm_query(struct ompi_communicator_t *comm, int *priority) * than or equal to 0, then the module is unavailable. */ *priority = mca_coll_inter_priority_param; if (0 >= mca_coll_inter_priority_param) { - return NULL; + return NULL; } size = ompi_comm_size(comm); rsize = ompi_comm_remote_size(comm); if ( size < mca_coll_inter_crossover && rsize < mca_coll_inter_crossover) { - return NULL; + return NULL; } inter_module = OBJ_NEW(mca_coll_inter_module_t); if (NULL == inter_module) { - return NULL; + return NULL; } inter_module->super.coll_module_enable = mca_coll_inter_module_enable; + inter_module->super.coll_module_disable = mca_coll_inter_module_disable; inter_module->super.coll_allgather = mca_coll_inter_allgather_inter; inter_module->super.coll_allgatherv = mca_coll_inter_allgatherv_inter; @@ -139,11 +146,24 @@ mca_coll_inter_comm_query(struct ompi_communicator_t *comm, int *priority) return &(inter_module->super); } - +#define INTER_INSTALL_COLL_API(__comm, __module, __api) \ + do \ + { \ + MCA_COLL_INSTALL_API(__comm, __api, __module->super.coll_##__api, &__module->super, "inter"); \ + } while (0) + +#define INTER_UNINSTALL_COLL_API(__comm, __module, __api) \ + do \ + { \ + if (__comm->c_coll->coll_##__api##_module == &__module->super) \ + { \ + MCA_COLL_INSTALL_API(__comm, __api, NULL, NULL, "inter"); \ + } \ + } while (0) /* * Init module on the communicator 
*/ -int +static int mca_coll_inter_module_enable(mca_coll_base_module_t *module, struct ompi_communicator_t *comm) { @@ -151,27 +171,35 @@ mca_coll_inter_module_enable(mca_coll_base_module_t *module, inter_module->inter_comm = comm; -#if 0 - if ( mca_coll_inter_verbose_param ) { - mca_coll_inter_dump_struct (data); - } -#endif - + INTER_INSTALL_COLL_API(comm, inter_module, allgather); + INTER_INSTALL_COLL_API(comm, inter_module, allgatherv); + INTER_INSTALL_COLL_API(comm, inter_module, allreduce); + INTER_INSTALL_COLL_API(comm, inter_module, bcast); + INTER_INSTALL_COLL_API(comm, inter_module, gather); + INTER_INSTALL_COLL_API(comm, inter_module, gatherv); + INTER_INSTALL_COLL_API(comm, inter_module, reduce); + INTER_INSTALL_COLL_API(comm, inter_module, scatter); + INTER_INSTALL_COLL_API(comm, inter_module, scatterv); return OMPI_SUCCESS; } - -#if 0 -static void mca_coll_inter_dump_struct ( struct mca_coll_base_comm_t *c) +static int +mca_coll_inter_module_disable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm) { - int rank; - - rank = ompi_comm_rank ( c->inter_comm ); + mca_coll_inter_module_t *inter_module = (mca_coll_inter_module_t*) module; - printf("%d: Dump of inter-struct for comm %s cid %u\n", - rank, c->inter_comm->c_name, c->inter_comm->c_contextid); + INTER_UNINSTALL_COLL_API(comm, inter_module, allgather); + INTER_UNINSTALL_COLL_API(comm, inter_module, allgatherv); + INTER_UNINSTALL_COLL_API(comm, inter_module, allreduce); + INTER_UNINSTALL_COLL_API(comm, inter_module, bcast); + INTER_UNINSTALL_COLL_API(comm, inter_module, gather); + INTER_UNINSTALL_COLL_API(comm, inter_module, gatherv); + INTER_UNINSTALL_COLL_API(comm, inter_module, reduce); + INTER_UNINSTALL_COLL_API(comm, inter_module, scatter); + INTER_UNINSTALL_COLL_API(comm, inter_module, scatterv); + inter_module->inter_comm = NULL; - return; + return OMPI_SUCCESS; } -#endif diff --git a/ompi/mca/coll/libnbc/coll_libnbc_component.c 
b/ompi/mca/coll/libnbc/coll_libnbc_component.c index 3b9662ea682..803b3ae46b1 100644 --- a/ompi/mca/coll/libnbc/coll_libnbc_component.c +++ b/ompi/mca/coll/libnbc/coll_libnbc_component.c @@ -19,6 +19,7 @@ * Copyright (c) 2017 Ian Bradley Morgan and Anthony Skjellum. All * rights reserved. * Copyright (c) 2018 FUJITSU LIMITED. All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -30,6 +31,7 @@ #include "coll_libnbc.h" #include "nbc_internal.h" +#include "ompi/mca/coll/base/base.h" #include "mpi.h" #include "ompi/mca/coll/coll.h" @@ -106,6 +108,7 @@ static int libnbc_register(void); static int libnbc_init_query(bool, bool); static mca_coll_base_module_t *libnbc_comm_query(struct ompi_communicator_t *, int *); static int libnbc_module_enable(mca_coll_base_module_t *, struct ompi_communicator_t *); +static int libnbc_module_disable(mca_coll_base_module_t *, struct ompi_communicator_t *); /* * Instantiate the public struct with all of our public information @@ -313,6 +316,7 @@ libnbc_comm_query(struct ompi_communicator_t *comm, *priority = libnbc_priority; module->super.coll_module_enable = libnbc_module_enable; + module->super.coll_module_disable = libnbc_module_disable; if (OMPI_COMM_IS_INTER(comm)) { module->super.coll_iallgather = ompi_coll_libnbc_iallgather_inter; module->super.coll_iallgatherv = ompi_coll_libnbc_iallgatherv_inter; @@ -407,7 +411,23 @@ libnbc_comm_query(struct ompi_communicator_t *comm, return &(module->super); } - +#define NBC_INSTALL_COLL_API(__comm, __module, __api) \ + do \ + { \ + if (__module->super.coll_##__api) \ + { \ + MCA_COLL_INSTALL_API(__comm, __api, __module->super.coll_##__api, &__module->super, "libnbc"); \ + } \ + } while (0) + +#define NBC_UNINSTALL_COLL_API(__comm, __module, __api) \ + do \ + { \ + if (__comm->c_coll->coll_##__api##_module == &__module->super) \ + { \ + MCA_COLL_INSTALL_API(__comm, __api, NULL, NULL, "libnbc"); \ + } \ + } 
while (0) /* * Init module on the communicator */ @@ -415,10 +435,120 @@ static int libnbc_module_enable(mca_coll_base_module_t *module, struct ompi_communicator_t *comm) { - /* All done */ + ompi_coll_libnbc_module_t *nbc_module = (ompi_coll_libnbc_module_t*)module; + + NBC_INSTALL_COLL_API(comm, nbc_module, iallgather); + NBC_INSTALL_COLL_API(comm, nbc_module, iallgatherv); + NBC_INSTALL_COLL_API(comm, nbc_module, iallreduce); + NBC_INSTALL_COLL_API(comm, nbc_module, ialltoall); + NBC_INSTALL_COLL_API(comm, nbc_module, ialltoallv); + NBC_INSTALL_COLL_API(comm, nbc_module, ialltoallw); + NBC_INSTALL_COLL_API(comm, nbc_module, ibarrier); + NBC_INSTALL_COLL_API(comm, nbc_module, ibcast); + NBC_INSTALL_COLL_API(comm, nbc_module, igather); + NBC_INSTALL_COLL_API(comm, nbc_module, igatherv); + NBC_INSTALL_COLL_API(comm, nbc_module, ireduce); + NBC_INSTALL_COLL_API(comm, nbc_module, ireduce_scatter); + NBC_INSTALL_COLL_API(comm, nbc_module, ireduce_scatter_block); + NBC_INSTALL_COLL_API(comm, nbc_module, iscatter); + NBC_INSTALL_COLL_API(comm, nbc_module, iscatterv); + + NBC_INSTALL_COLL_API(comm, nbc_module, allgather_init); + NBC_INSTALL_COLL_API(comm, nbc_module, allgatherv_init); + NBC_INSTALL_COLL_API(comm, nbc_module, allreduce_init); + NBC_INSTALL_COLL_API(comm, nbc_module, alltoall_init); + NBC_INSTALL_COLL_API(comm, nbc_module, alltoallv_init); + NBC_INSTALL_COLL_API(comm, nbc_module, alltoallw_init); + NBC_INSTALL_COLL_API(comm, nbc_module, barrier_init); + NBC_INSTALL_COLL_API(comm, nbc_module, bcast_init); + NBC_INSTALL_COLL_API(comm, nbc_module, gather_init); + NBC_INSTALL_COLL_API(comm, nbc_module, gatherv_init); + NBC_INSTALL_COLL_API(comm, nbc_module, reduce_init); + NBC_INSTALL_COLL_API(comm, nbc_module, reduce_scatter_init); + NBC_INSTALL_COLL_API(comm, nbc_module, reduce_scatter_block_init); + NBC_INSTALL_COLL_API(comm, nbc_module, scatter_init); + NBC_INSTALL_COLL_API(comm, nbc_module, scatterv_init); + + if (!OMPI_COMM_IS_INTER(comm)) { + 
NBC_INSTALL_COLL_API(comm, nbc_module, exscan_init); + NBC_INSTALL_COLL_API(comm, nbc_module, iexscan); + NBC_INSTALL_COLL_API(comm, nbc_module, scan_init); + NBC_INSTALL_COLL_API(comm, nbc_module, iscan); + + NBC_INSTALL_COLL_API(comm, nbc_module, ineighbor_allgather); + NBC_INSTALL_COLL_API(comm, nbc_module, ineighbor_allgatherv); + NBC_INSTALL_COLL_API(comm, nbc_module, ineighbor_alltoall); + NBC_INSTALL_COLL_API(comm, nbc_module, ineighbor_alltoallv); + NBC_INSTALL_COLL_API(comm, nbc_module, ineighbor_alltoallw); + + NBC_INSTALL_COLL_API(comm, nbc_module, neighbor_allgather_init); + NBC_INSTALL_COLL_API(comm, nbc_module, neighbor_allgatherv_init); + NBC_INSTALL_COLL_API(comm, nbc_module, neighbor_alltoall_init); + NBC_INSTALL_COLL_API(comm, nbc_module, neighbor_alltoallv_init); + NBC_INSTALL_COLL_API(comm, nbc_module, neighbor_alltoallw_init); + } /* All done */ + return OMPI_SUCCESS; } +static int +libnbc_module_disable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm) +{ + ompi_coll_libnbc_module_t *nbc_module = (ompi_coll_libnbc_module_t*)module; + + NBC_UNINSTALL_COLL_API(comm, nbc_module, iallgather); + NBC_UNINSTALL_COLL_API(comm, nbc_module, iallgatherv); + NBC_UNINSTALL_COLL_API(comm, nbc_module, iallreduce); + NBC_UNINSTALL_COLL_API(comm, nbc_module, ialltoall); + NBC_UNINSTALL_COLL_API(comm, nbc_module, ialltoallv); + NBC_UNINSTALL_COLL_API(comm, nbc_module, ialltoallw); + NBC_UNINSTALL_COLL_API(comm, nbc_module, ibarrier); + NBC_UNINSTALL_COLL_API(comm, nbc_module, ibcast); + NBC_UNINSTALL_COLL_API(comm, nbc_module, igather); + NBC_UNINSTALL_COLL_API(comm, nbc_module, igatherv); + NBC_UNINSTALL_COLL_API(comm, nbc_module, ireduce); + NBC_UNINSTALL_COLL_API(comm, nbc_module, ireduce_scatter); + NBC_UNINSTALL_COLL_API(comm, nbc_module, ireduce_scatter_block); + NBC_UNINSTALL_COLL_API(comm, nbc_module, iscatter); + NBC_UNINSTALL_COLL_API(comm, nbc_module, iscatterv); + + NBC_UNINSTALL_COLL_API(comm, nbc_module, allgather_init); + 
NBC_UNINSTALL_COLL_API(comm, nbc_module, allgatherv_init); + NBC_UNINSTALL_COLL_API(comm, nbc_module, allreduce_init); + NBC_UNINSTALL_COLL_API(comm, nbc_module, alltoall_init); + NBC_UNINSTALL_COLL_API(comm, nbc_module, alltoallv_init); + NBC_UNINSTALL_COLL_API(comm, nbc_module, alltoallw_init); + NBC_UNINSTALL_COLL_API(comm, nbc_module, barrier_init); + NBC_UNINSTALL_COLL_API(comm, nbc_module, bcast_init); + NBC_UNINSTALL_COLL_API(comm, nbc_module, gather_init); + NBC_UNINSTALL_COLL_API(comm, nbc_module, gatherv_init); + NBC_UNINSTALL_COLL_API(comm, nbc_module, reduce_init); + NBC_UNINSTALL_COLL_API(comm, nbc_module, reduce_scatter_init); + NBC_UNINSTALL_COLL_API(comm, nbc_module, reduce_scatter_block_init); + NBC_UNINSTALL_COLL_API(comm, nbc_module, scatter_init); + NBC_UNINSTALL_COLL_API(comm, nbc_module, scatterv_init); + + if (!OMPI_COMM_IS_INTER(comm)) { + NBC_UNINSTALL_COLL_API(comm, nbc_module, exscan_init); + NBC_UNINSTALL_COLL_API(comm, nbc_module, iexscan); + NBC_UNINSTALL_COLL_API(comm, nbc_module, scan_init); + NBC_UNINSTALL_COLL_API(comm, nbc_module, iscan); + + NBC_UNINSTALL_COLL_API(comm, nbc_module, ineighbor_allgather); + NBC_UNINSTALL_COLL_API(comm, nbc_module, ineighbor_allgatherv); + NBC_UNINSTALL_COLL_API(comm, nbc_module, ineighbor_alltoall); + NBC_UNINSTALL_COLL_API(comm, nbc_module, ineighbor_alltoallv); + NBC_UNINSTALL_COLL_API(comm, nbc_module, ineighbor_alltoallw); + + NBC_UNINSTALL_COLL_API(comm, nbc_module, neighbor_allgather_init); + NBC_UNINSTALL_COLL_API(comm, nbc_module, neighbor_allgatherv_init); + NBC_UNINSTALL_COLL_API(comm, nbc_module, neighbor_alltoall_init); + NBC_UNINSTALL_COLL_API(comm, nbc_module, neighbor_alltoallv_init); + NBC_UNINSTALL_COLL_API(comm, nbc_module, neighbor_alltoallw_init); + } /* All done */ + return OMPI_SUCCESS; +} int ompi_coll_libnbc_progress(void) diff --git a/ompi/mca/coll/monitoring/coll_monitoring.h b/ompi/mca/coll/monitoring/coll_monitoring.h index 3f741970907..49d9d10b865 100644 --- 
a/ompi/mca/coll/monitoring/coll_monitoring.h +++ b/ompi/mca/coll/monitoring/coll_monitoring.h @@ -4,6 +4,7 @@ * and Technology (RIST). All rights reserved. * Copyright (c) 2017 Amazon.com, Inc. or its affiliates. All Rights * reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -36,7 +37,6 @@ struct mca_coll_monitoring_module_t { mca_coll_base_module_t super; mca_coll_base_comm_coll_t real; mca_monitoring_coll_data_t*data; - opal_atomic_int32_t is_initialized; }; typedef struct mca_coll_monitoring_module_t mca_coll_monitoring_module_t; OMPI_DECLSPEC OBJ_CLASS_DECLARATION(mca_coll_monitoring_module_t); diff --git a/ompi/mca/coll/monitoring/coll_monitoring_component.c b/ompi/mca/coll/monitoring/coll_monitoring_component.c index b23c3308ca2..cc1f56293d6 100644 --- a/ompi/mca/coll/monitoring/coll_monitoring_component.c +++ b/ompi/mca/coll/monitoring/coll_monitoring_component.c @@ -7,6 +7,7 @@ * reserved. * Copyright (c) 2019 Research Organization for Information Science * and Technology (RIST). All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. 
* $COPYRIGHT$ * * Additional copyrights may follow @@ -21,49 +22,43 @@ #include "ompi/mca/coll/coll.h" #include "opal/mca/base/mca_base_component_repository.h" -#define MONITORING_SAVE_PREV_COLL_API(__module, __comm, __api) \ - do { \ - if( NULL != __comm->c_coll->coll_ ## __api ## _module ) { \ - __module->real.coll_ ## __api = __comm->c_coll->coll_ ## __api; \ - __module->real.coll_ ## __api ## _module = __comm->c_coll->coll_ ## __api ## _module; \ - OBJ_RETAIN(__module->real.coll_ ## __api ## _module); \ - } else { \ - /* If no function previously provided, do not monitor */ \ - __module->super.coll_ ## __api = NULL; \ - OPAL_MONITORING_PRINT_WARN("COMM \"%s\": No monitoring available for " \ - "coll_" # __api, __comm->c_name); \ - } \ - if( NULL != __comm->c_coll->coll_i ## __api ## _module ) { \ - __module->real.coll_i ## __api = __comm->c_coll->coll_i ## __api; \ - __module->real.coll_i ## __api ## _module = __comm->c_coll->coll_i ## __api ## _module; \ - OBJ_RETAIN(__module->real.coll_i ## __api ## _module); \ - } else { \ - /* If no function previously provided, do not monitor */ \ - __module->super.coll_i ## __api = NULL; \ - OPAL_MONITORING_PRINT_WARN("COMM \"%s\": No monitoring available for " \ - "coll_i" # __api, __comm->c_name); \ - } \ - } while(0) - -#define MONITORING_RELEASE_PREV_COLL_API(__module, __comm, __api) \ - do { \ - if( NULL != __module->real.coll_ ## __api ## _module ) { \ - if( NULL != __module->real.coll_ ## __api ## _module->coll_module_disable ) { \ - __module->real.coll_ ## __api ## _module->coll_module_disable(__module->real.coll_ ## __api ## _module, __comm); \ - } \ - OBJ_RELEASE(__module->real.coll_ ## __api ## _module); \ - __module->real.coll_ ## __api = NULL; \ - __module->real.coll_ ## __api ## _module = NULL; \ - } \ - if( NULL != __module->real.coll_i ## __api ## _module ) { \ - if( NULL != __module->real.coll_i ## __api ## _module->coll_module_disable ) { \ - __module->real.coll_i ## __api ## 
_module->coll_module_disable(__module->real.coll_i ## __api ## _module, __comm); \ - } \ - OBJ_RELEASE(__module->real.coll_i ## __api ## _module); \ - __module->real.coll_i ## __api = NULL; \ - __module->real.coll_i ## __api ## _module = NULL; \ - } \ - } while(0) +#define MONITORING_INSTALL_COLL_API(__module, __comm, __api) \ + do \ + { \ + if ((NULL != __comm->c_coll->coll_##__api##_module) && \ + (NULL != __module->super.coll_##__api)) \ + { \ + /* save the current selected collective */ \ + MCA_COLL_SAVE_API(__comm, __api, __module->real.coll_##__api, __module->real.coll_##__api##_module, "monitoring"); \ + /* install our own */ \ + MCA_COLL_INSTALL_API(__comm, __api, __module->super.coll_##__api, &__module->super, "monitoring"); \ + } \ + if ((NULL != __comm->c_coll->coll_i##__api##_module) && \ + (NULL != __module->super.coll_i##__api)) \ + { \ + /* save the current selected collective */ \ + MCA_COLL_SAVE_API(__comm, i##__api, __module->real.coll_i##__api, __module->real.coll_i##__api##_module, "monitoring"); \ + /* install our own */ \ + MCA_COLL_INSTALL_API(__comm, i##__api, __module->super.coll_i##__api, &__module->super, "monitoring"); \ + } \ + } while (0) + +#define MONITORING_UNINSTALL_COLL_API(__module, __comm, __api) \ + do \ + { \ + if (&__module->super == __comm->c_coll->coll_##__api##_module) \ + { \ + MCA_COLL_INSTALL_API(__comm, __api, __module->real.coll_##__api, __module->real.coll_##__api##_module, "monitoring"); \ + __module->real.coll_##__api = NULL; \ + __module->real.coll_##__api##_module = NULL; \ + } \ + if (&__module->super == __comm->c_coll->coll_i##__api##_module) \ + { \ + MCA_COLL_INSTALL_API(__comm, i##__api, __module->real.coll_i##__api, __module->real.coll_i##__api##_module, "monitoring"); \ + __module->real.coll_i##__api = NULL; \ + __module->real.coll_i##__api##_module = NULL; \ + } \ + } while (0) #define MONITORING_SET_FULL_PREV_COLL_API(m, c, operation) \ do { \ @@ -91,11 +86,11 @@ operation(m, c, neighbor_alltoallw); \ } 
while(0) -#define MONITORING_SAVE_FULL_PREV_COLL_API(m, c) \ - MONITORING_SET_FULL_PREV_COLL_API((m), (c), MONITORING_SAVE_PREV_COLL_API) +#define MONITORING_SAVE_FULL_PREV_COLL_API(m, c) \ + MONITORING_SET_FULL_PREV_COLL_API((m), (c), MONITORING_INSTALL_COLL_API) -#define MONITORING_RELEASE_FULL_PREV_COLL_API(m, c) \ - MONITORING_SET_FULL_PREV_COLL_API((m), (c), MONITORING_RELEASE_PREV_COLL_API) +#define MONITORING_RELEASE_FULL_PREV_COLL_API(m, c) \ + MONITORING_SET_FULL_PREV_COLL_API((m), (c), MONITORING_UNINSTALL_COLL_API) static int mca_coll_monitoring_component_open(void) { @@ -125,11 +120,11 @@ static int mca_coll_monitoring_module_enable(mca_coll_base_module_t*module, struct ompi_communicator_t*comm) { mca_coll_monitoring_module_t*monitoring_module = (mca_coll_monitoring_module_t*) module; - if( 1 == opal_atomic_add_fetch_32(&monitoring_module->is_initialized, 1) ) { - MONITORING_SAVE_FULL_PREV_COLL_API(monitoring_module, comm); - monitoring_module->data = mca_common_monitoring_coll_new(comm); - OPAL_MONITORING_PRINT_INFO("coll_module_enabled"); - } + + MONITORING_SAVE_FULL_PREV_COLL_API(monitoring_module, comm); + monitoring_module->data = mca_common_monitoring_coll_new(comm); + OPAL_MONITORING_PRINT_INFO("coll_module_enabled"); + return OMPI_SUCCESS; } @@ -137,12 +132,12 @@ static int mca_coll_monitoring_module_disable(mca_coll_base_module_t*module, struct ompi_communicator_t*comm) { mca_coll_monitoring_module_t*monitoring_module = (mca_coll_monitoring_module_t*) module; - if( 0 == opal_atomic_sub_fetch_32(&monitoring_module->is_initialized, 1) ) { - MONITORING_RELEASE_FULL_PREV_COLL_API(monitoring_module, comm); - mca_common_monitoring_coll_release(monitoring_module->data); - monitoring_module->data = NULL; - OPAL_MONITORING_PRINT_INFO("coll_module_disabled"); - } + + MONITORING_RELEASE_FULL_PREV_COLL_API(monitoring_module, comm); + mca_common_monitoring_coll_release(monitoring_module->data); + monitoring_module->data = NULL; + 
OPAL_MONITORING_PRINT_INFO("coll_module_disabled"); + return OMPI_SUCCESS; } @@ -208,9 +203,6 @@ mca_coll_monitoring_component_query(struct ompi_communicator_t*comm, int*priorit monitoring_module->super.coll_ineighbor_alltoallv = mca_coll_monitoring_ineighbor_alltoallv; monitoring_module->super.coll_ineighbor_alltoallw = mca_coll_monitoring_ineighbor_alltoallw; - /* Initialization flag */ - monitoring_module->is_initialized = 0; - *priority = mca_coll_monitoring_component.priority; return &(monitoring_module->super); diff --git a/ompi/mca/coll/portals4/coll_portals4_component.c b/ompi/mca/coll/portals4/coll_portals4_component.c index 6b0a8573371..11b476ceb20 100644 --- a/ompi/mca/coll/portals4/coll_portals4_component.c +++ b/ompi/mca/coll/portals4/coll_portals4_component.c @@ -14,6 +14,7 @@ * Copyright (c) 2015 Los Alamos National Security, LLC. All rights * reserved. * Copyright (c) 2015 Bull SAS. All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -127,20 +128,37 @@ ptl_datatype_t ompi_coll_portals4_atomic_datatype [OMPI_DATATYPE_MPI_MAX_PREDEFI }; - -#define PORTALS4_SAVE_PREV_COLL_API(__module, __comm, __api) \ - do { \ - __module->previous_ ## __api = __comm->c_coll->coll_ ## __api; \ - __module->previous_ ## __api ## _module = __comm->c_coll->coll_ ## __api ## _module; \ - if (!comm->c_coll->coll_ ## __api || !comm->c_coll->coll_ ## __api ## _module) { \ - opal_output_verbose(1, ompi_coll_base_framework.framework_output, \ - "(%d/%s): no underlying " # __api"; disqualifying myself", \ - ompi_comm_get_local_cid(__comm), __comm->c_name); \ - return OMPI_ERROR; \ - } \ - OBJ_RETAIN(__module->previous_ ## __api ## _module); \ - } while(0) - +#define PORTALS4_INSTALL_COLL_API(__module, __comm, __api) \ + do \ + { \ + if (!comm->c_coll->coll_##__api || !comm->c_coll->coll_##__api##_module) \ + { \ + opal_output_verbose(1, ompi_coll_base_framework.framework_output, \ + "(%d/%s): 
no underlying " #__api "; disqualifying myself", \ + __comm->c_contextid, __comm->c_name); \ + __module->previous_##__api = NULL; \ + __module->previous_##__api##_module = NULL; \ + } \ + else \ + { \ + /* save the current selected collective */ \ + MCA_COLL_SAVE_API(__comm, __api, __module->previous_##__api, __module->previous_##__api##_module, "portals"); \ + /* install our own */ \ + MCA_COLL_INSTALL_API(__comm, __api, __module->super.coll##__api, &__module->super, "portals"); \ + } \ + } while (0) + +#define PORTALS4_UNINSTALL_COLL_API(__module, __comm, __api) \ + do \ + { \ + if ((&__module->super == comm->c_coll->coll_##__api##_module) && \ + (NULL != __module->previous_##__api##_module)) \ + { \ + MCA_COLL_INSTALL_API(__comm, __api, __module->previous_##__api, __module->previous_##__api##_module, "portals"); \ + __module->previous_##__api = NULL; \ + __module->previous_##__api##_module = NULL; \ + } \ + } while (0) const char *mca_coll_portals4_component_version_string = "Open MPI Portals 4 collective MCA component version " OMPI_VERSION; @@ -158,6 +176,8 @@ static mca_coll_base_module_t* portals4_comm_query(struct ompi_communicator_t *c int *priority); static int portals4_module_enable(mca_coll_base_module_t *module, struct ompi_communicator_t *comm); +static int portals4_module_disable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm); static int portals4_progress(void); @@ -618,6 +638,7 @@ portals4_comm_query(struct ompi_communicator_t *comm, *priority = mca_coll_portals4_priority; portals4_module->coll_count = 0; portals4_module->super.coll_module_enable = portals4_module_enable; + portals4_module->super.coll_module_disable = portals4_module_disable; portals4_module->super.coll_barrier = ompi_coll_portals4_barrier_intra; portals4_module->super.coll_ibarrier = ompi_coll_portals4_ibarrier_intra; @@ -653,14 +674,45 @@ portals4_module_enable(mca_coll_base_module_t *module, { mca_coll_portals4_module_t *portals4_module = 
(mca_coll_portals4_module_t*) module; - PORTALS4_SAVE_PREV_COLL_API(portals4_module, comm, allreduce); - PORTALS4_SAVE_PREV_COLL_API(portals4_module, comm, iallreduce); - PORTALS4_SAVE_PREV_COLL_API(portals4_module, comm, reduce); - PORTALS4_SAVE_PREV_COLL_API(portals4_module, comm, ireduce); + PORTALS4_INSTALL_COLL_API(portals4_module, comm, iallreduce); + PORTALS4_INSTALL_COLL_API(portals4_module, comm, allreduce); + PORTALS4_INSTALL_COLL_API(portals4_module, comm, ireduce); + PORTALS4_INSTALL_COLL_API(portals4_module, comm, reduce); + + PORTALS4_INSTALL_COLL_API(portals4_module, comm, barrier); + PORTALS4_INSTALL_COLL_API(portals4_module, comm, ibarrier); + PORTALS4_INSTALL_COLL_API(portals4_module, comm, gather); + PORTALS4_INSTALL_COLL_API(portals4_module, comm, igather); + PORTALS4_INSTALL_COLL_API(portals4_module, comm, scatter); + PORTALS4_INSTALL_COLL_API(portals4_module, comm, iscatter); + PORTALS4_INSTALL_COLL_API(portals4_module, comm, bcast); + PORTALS4_INSTALL_COLL_API(portals4_module, comm, ibcast); return OMPI_SUCCESS; } +static int +portals4_module_disable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm) +{ + mca_coll_portals4_module_t *portals4_module = (mca_coll_portals4_module_t *)module; + + PORTALS4_UNINSTALL_COLL_API(portals4_module, comm, allreduce); + PORTALS4_UNINSTALL_COLL_API(portals4_module, comm, iallreduce); + PORTALS4_UNINSTALL_COLL_API(portals4_module, comm, reduce); + PORTALS4_UNINSTALL_COLL_API(portals4_module, comm, ireduce); + + PORTALS4_UNINSTALL_COLL_API(portals4_module, comm, barrier); + PORTALS4_UNINSTALL_COLL_API(portals4_module, comm, ibarrier); + PORTALS4_UNINSTALL_COLL_API(portals4_module, comm, gather); + PORTALS4_UNINSTALL_COLL_API(portals4_module, comm, igather); + PORTALS4_UNINSTALL_COLL_API(portals4_module, comm, scatter); + PORTALS4_UNINSTALL_COLL_API(portals4_module, comm, iscatter); + PORTALS4_UNINSTALL_COLL_API(portals4_module, comm, bcast); + PORTALS4_UNINSTALL_COLL_API(portals4_module, comm, 
ibcast); + + return OMPI_SUCCESS; +} #if OPAL_ENABLE_DEBUG /* These string maps are only used for debugging output. * They will be compiled-out when OPAL is configured diff --git a/ompi/mca/coll/self/coll_self.h b/ompi/mca/coll/self/coll_self.h index 1dad4536d67..1585e3bc8a1 100644 --- a/ompi/mca/coll/self/coll_self.h +++ b/ompi/mca/coll/self/coll_self.h @@ -12,6 +12,7 @@ * Copyright (c) 2008 Cisco Systems, Inc. All rights reserved. * Copyright (c) 2015 Research Organization for Information Science * and Technology (RIST). All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -50,9 +51,6 @@ int mca_coll_self_init_query(bool enable_progress_threads, mca_coll_base_module_t * mca_coll_self_comm_query(struct ompi_communicator_t *comm, int *priority); -int mca_coll_self_module_enable(mca_coll_base_module_t *module, - struct ompi_communicator_t *comm); - int mca_coll_self_allgather_intra(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void *rbuf, int rcount, diff --git a/ompi/mca/coll/self/coll_self_module.c b/ompi/mca/coll/self/coll_self_module.c index d782b998165..ec706b52613 100644 --- a/ompi/mca/coll/self/coll_self_module.c +++ b/ompi/mca/coll/self/coll_self_module.c @@ -10,6 +10,7 @@ * Copyright (c) 2004-2005 The Regents of the University of California. * All rights reserved. * Copyright (c) 2017 IBM Corporation. All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. 
* $COPYRIGHT$ * * Additional copyrights may follow @@ -29,6 +30,12 @@ #include "ompi/mca/coll/base/coll_base_functions.h" #include "coll_self.h" +static int +mca_coll_self_module_enable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm); +static int +mca_coll_self_module_disable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm); /* * Initial query function that is invoked during MPI_INIT, allowing @@ -63,22 +70,7 @@ mca_coll_self_comm_query(struct ompi_communicator_t *comm, if (NULL == module) return NULL; module->super.coll_module_enable = mca_coll_self_module_enable; - module->super.coll_allgather = mca_coll_self_allgather_intra; - module->super.coll_allgatherv = mca_coll_self_allgatherv_intra; - module->super.coll_allreduce = mca_coll_self_allreduce_intra; - module->super.coll_alltoall = mca_coll_self_alltoall_intra; - module->super.coll_alltoallv = mca_coll_self_alltoallv_intra; - module->super.coll_alltoallw = mca_coll_self_alltoallw_intra; - module->super.coll_barrier = mca_coll_self_barrier_intra; - module->super.coll_bcast = mca_coll_self_bcast_intra; - module->super.coll_exscan = mca_coll_self_exscan_intra; - module->super.coll_gather = mca_coll_self_gather_intra; - module->super.coll_gatherv = mca_coll_self_gatherv_intra; - module->super.coll_reduce = mca_coll_self_reduce_intra; - module->super.coll_reduce_scatter = mca_coll_self_reduce_scatter_intra; - module->super.coll_scan = mca_coll_self_scan_intra; - module->super.coll_scatter = mca_coll_self_scatter_intra; - module->super.coll_scatterv = mca_coll_self_scatterv_intra; + module->super.coll_module_disable = mca_coll_self_module_disable; module->super.coll_reduce_local = mca_coll_base_reduce_local; @@ -88,13 +80,73 @@ mca_coll_self_comm_query(struct ompi_communicator_t *comm, return NULL; } +#define SELF_INSTALL_COLL_API(__comm, __module, __api) \ + do \ + { \ + MCA_COLL_INSTALL_API(__comm, __api, mca_coll_self_##__api##_intra, &__module->super, "self"); \ + } while 
(0) + +#define SELF_UNINSTALL_COLL_API(__comm, __module, __api) \ + do \ + { \ + if (__comm->c_coll->coll_##__api##_module == &__module->super) \ + { \ + MCA_COLL_INSTALL_API(__comm, __api, NULL, NULL, "self"); \ + } \ + } while (0) /* * Init module on the communicator */ -int +static int mca_coll_self_module_enable(mca_coll_base_module_t *module, struct ompi_communicator_t *comm) { + mca_coll_self_module_t *sm_module = (mca_coll_self_module_t*)module; + + SELF_INSTALL_COLL_API(comm, sm_module, allgather); + SELF_INSTALL_COLL_API(comm, sm_module, allgatherv); + SELF_INSTALL_COLL_API(comm, sm_module, allreduce); + SELF_INSTALL_COLL_API(comm, sm_module, alltoall); + SELF_INSTALL_COLL_API(comm, sm_module, alltoallv); + SELF_INSTALL_COLL_API(comm, sm_module, alltoallw); + SELF_INSTALL_COLL_API(comm, sm_module, barrier); + SELF_INSTALL_COLL_API(comm, sm_module, bcast); + SELF_INSTALL_COLL_API(comm, sm_module, exscan); + SELF_INSTALL_COLL_API(comm, sm_module, gather); + SELF_INSTALL_COLL_API(comm, sm_module, gatherv); + SELF_INSTALL_COLL_API(comm, sm_module, reduce); + SELF_INSTALL_COLL_API(comm, sm_module, reduce_scatter); + SELF_INSTALL_COLL_API(comm, sm_module, scan); + SELF_INSTALL_COLL_API(comm, sm_module, scatter); + SELF_INSTALL_COLL_API(comm, sm_module, scatterv); + + MCA_COLL_INSTALL_API(comm, reduce_local, mca_coll_base_reduce_local, module, "self"); + return OMPI_SUCCESS; +} + +static int +mca_coll_self_module_disable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm) +{ + mca_coll_self_module_t *sm_module = (mca_coll_self_module_t *)module; + + SELF_UNINSTALL_COLL_API(comm, sm_module, allgather); + SELF_UNINSTALL_COLL_API(comm, sm_module, allgatherv); + SELF_UNINSTALL_COLL_API(comm, sm_module, allreduce); + SELF_UNINSTALL_COLL_API(comm, sm_module, alltoall); + SELF_UNINSTALL_COLL_API(comm, sm_module, alltoallv); + SELF_UNINSTALL_COLL_API(comm, sm_module, alltoallw); + SELF_UNINSTALL_COLL_API(comm, sm_module, barrier); + 
SELF_UNINSTALL_COLL_API(comm, sm_module, bcast); + SELF_UNINSTALL_COLL_API(comm, sm_module, exscan); + SELF_UNINSTALL_COLL_API(comm, sm_module, gather); + SELF_UNINSTALL_COLL_API(comm, sm_module, gatherv); + SELF_UNINSTALL_COLL_API(comm, sm_module, reduce); + SELF_UNINSTALL_COLL_API(comm, sm_module, reduce_scatter); + SELF_UNINSTALL_COLL_API(comm, sm_module, scan); + SELF_UNINSTALL_COLL_API(comm, sm_module, scatter); + SELF_UNINSTALL_COLL_API(comm, sm_module, scatterv); + return OMPI_SUCCESS; } diff --git a/ompi/mca/coll/sync/Makefile.am b/ompi/mca/coll/sync/Makefile.am index 2f75cd2dfa5..dad4f8f7138 100644 --- a/ompi/mca/coll/sync/Makefile.am +++ b/ompi/mca/coll/sync/Makefile.am @@ -12,6 +12,7 @@ # Copyright (c) 2009 Cisco Systems, Inc. All rights reserved. # Copyright (c) 2016 Intel, Inc. All rights reserved # Copyright (c) 2017 IBM Corporation. All rights reserved. +# Copyright (c) 2024 NVIDIA Corporation. All rights reserved. # $COPYRIGHT$ # # Additional copyrights may follow @@ -19,8 +20,6 @@ # $HEADER$ # -dist_ompidata_DATA = help-coll-sync.txt - sources = \ coll_sync.h \ coll_sync_component.c \ diff --git a/ompi/mca/coll/sync/coll_sync.h b/ompi/mca/coll/sync/coll_sync.h index 76913b9615c..b6617f9ef85 100644 --- a/ompi/mca/coll/sync/coll_sync.h +++ b/ompi/mca/coll/sync/coll_sync.h @@ -10,6 +10,7 @@ * Copyright (c) 2004-2005 The Regents of the University of California. * All rights reserved. * Copyright (c) 2008-2009 Cisco Systems, Inc. All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. 
* $COPYRIGHT$ * * Additional copyrights may follow @@ -43,9 +44,6 @@ mca_coll_base_module_t *mca_coll_sync_comm_query(struct ompi_communicator_t *comm, int *priority); -int mca_coll_sync_module_enable(mca_coll_base_module_t *module, - struct ompi_communicator_t *comm); - int mca_coll_sync_barrier(struct ompi_communicator_t *comm, mca_coll_base_module_t *module); diff --git a/ompi/mca/coll/sync/coll_sync_module.c b/ompi/mca/coll/sync/coll_sync_module.c index 128cb8c3fa9..65b32f1ba16 100644 --- a/ompi/mca/coll/sync/coll_sync_module.c +++ b/ompi/mca/coll/sync/coll_sync_module.c @@ -13,6 +13,7 @@ * Copyright (c) 2016 Research Organization for Information Science * and Technology (RIST). All rights reserved. * Copyright (c) 2018-2019 Intel, Inc. All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -40,6 +41,12 @@ #include "ompi/mca/coll/base/base.h" #include "coll_sync.h" +static int +mca_coll_sync_module_enable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm); +static int +mca_coll_sync_module_disable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm); static void mca_coll_sync_module_construct(mca_coll_sync_module_t *module) { @@ -111,16 +118,10 @@ mca_coll_sync_comm_query(struct ompi_communicator_t *comm, /* Choose whether to use [intra|inter] */ sync_module->super.coll_module_enable = mca_coll_sync_module_enable; + sync_module->super.coll_module_disable = mca_coll_sync_module_disable; /* The "all" versions are already synchronous. So no need for an additional barrier there. 
*/ - sync_module->super.coll_allgather = NULL; - sync_module->super.coll_allgatherv = NULL; - sync_module->super.coll_allreduce = NULL; - sync_module->super.coll_alltoall = NULL; - sync_module->super.coll_alltoallv = NULL; - sync_module->super.coll_alltoallw = NULL; - sync_module->super.coll_barrier = NULL; sync_module->super.coll_bcast = mca_coll_sync_bcast; sync_module->super.coll_exscan = mca_coll_sync_exscan; sync_module->super.coll_gather = mca_coll_sync_gather; @@ -134,47 +135,78 @@ mca_coll_sync_comm_query(struct ompi_communicator_t *comm, return &(sync_module->super); } +#define SYNC_INSTALL_COLL_API(__comm, __module, __api) \ + do \ + { \ + /* save the current selected collective */ \ + MCA_COLL_SAVE_API(__comm, __api, __module->c_coll.coll_##__api, __module->c_coll.coll_##__api##_module, "sync"); \ + /* install our own */ \ + MCA_COLL_INSTALL_API(__comm, __api, mca_coll_sync_##__api, &__module->super, "sync"); \ + } while (0) + +#define SYNC_UNINSTALL_COLL_API(__comm, __module, __api) \ + do \ + { \ + if (__comm->c_coll->coll_##__api##_module == &__module->super) \ + { \ + MCA_COLL_INSTALL_API(__comm, __api, mca_coll_sync_##__api, __module->c_coll.coll_##__api##_module, "sync"); \ + __module->c_coll.coll_##__api = NULL; \ + __module->c_coll.coll_##__api##_module = NULL; \ + } \ + } while (0) /* * Init module on the communicator */ -int mca_coll_sync_module_enable(mca_coll_base_module_t *module, - struct ompi_communicator_t *comm) +static int +mca_coll_sync_module_enable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm) { - bool good = true; - char *msg = NULL; - mca_coll_sync_module_t *s = (mca_coll_sync_module_t*) module; + mca_coll_sync_module_t *sync_module = (mca_coll_sync_module_t*) module; - /* Save the prior layer of coll functions */ - s->c_coll = *comm->c_coll; - -#define CHECK_AND_RETAIN(name) \ - if (NULL == s->c_coll.coll_ ## name ## _module) { \ - good = false; \ - msg = #name; \ - } else if (good) { \ - 
OBJ_RETAIN(s->c_coll.coll_ ## name ## _module); \ - } + /* The "all" versions are already synchronous. So no need for an + additional barrier there. */ + SYNC_INSTALL_COLL_API(comm, sync_module, bcast); + SYNC_INSTALL_COLL_API(comm, sync_module, gather); + SYNC_INSTALL_COLL_API(comm, sync_module, gatherv); + SYNC_INSTALL_COLL_API(comm, sync_module, reduce); + SYNC_INSTALL_COLL_API(comm, sync_module, reduce_scatter); + SYNC_INSTALL_COLL_API(comm, sync_module, scatter); + SYNC_INSTALL_COLL_API(comm, sync_module, scatterv); - CHECK_AND_RETAIN(bcast); - CHECK_AND_RETAIN(gather); - CHECK_AND_RETAIN(gatherv); - CHECK_AND_RETAIN(reduce); - CHECK_AND_RETAIN(reduce_scatter); - CHECK_AND_RETAIN(scatter); - CHECK_AND_RETAIN(scatterv); if (!OMPI_COMM_IS_INTER(comm)) { /* MPI does not define scan/exscan on intercommunicators */ - CHECK_AND_RETAIN(exscan); - CHECK_AND_RETAIN(scan); + SYNC_INSTALL_COLL_API(comm, sync_module, scan); + SYNC_INSTALL_COLL_API(comm, sync_module, exscan); } - /* All done */ - if (good) { - return OMPI_SUCCESS; + return OMPI_SUCCESS; +} + +static int +mca_coll_sync_module_disable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm) +{ + mca_coll_sync_module_t *sync_module = (mca_coll_sync_module_t*) module; + + /* Save the prior layer of coll functions */ + sync_module->c_coll = *comm->c_coll; + + /* The "all" versions are already synchronous. So no need for an + additional barrier there. 
*/ + SYNC_UNINSTALL_COLL_API(comm, sync_module, bcast); + SYNC_UNINSTALL_COLL_API(comm, sync_module, gather); + SYNC_UNINSTALL_COLL_API(comm, sync_module, gatherv); + SYNC_UNINSTALL_COLL_API(comm, sync_module, reduce); + SYNC_UNINSTALL_COLL_API(comm, sync_module, reduce_scatter); + SYNC_UNINSTALL_COLL_API(comm, sync_module, scatter); + SYNC_UNINSTALL_COLL_API(comm, sync_module, scatterv); + + if (!OMPI_COMM_IS_INTER(comm)) { + /* MPI does not define scan/exscan on intercommunicators */ + SYNC_INSTALL_COLL_API(comm, sync_module, scan); + SYNC_INSTALL_COLL_API(comm, sync_module, exscan); } - opal_show_help("help-coll-sync.txt", "missing collective", true, - ompi_process_info.nodename, - mca_coll_sync_component.priority, msg); - return OMPI_ERR_NOT_FOUND; + + return OMPI_SUCCESS; } diff --git a/ompi/mca/coll/sync/help-coll-sync.txt b/ompi/mca/coll/sync/help-coll-sync.txt deleted file mode 100644 index 4a5c871207e..00000000000 --- a/ompi/mca/coll/sync/help-coll-sync.txt +++ /dev/null @@ -1,22 +0,0 @@ -# -*- text -*- -# -# Copyright (c) 2009 Cisco Systems, Inc. All rights reserved. -# $COPYRIGHT$ -# -# Additional copyrights may follow -# -# $HEADER$ -# -# This is the US/English general help file for Open MPI's sync -# collective component. -# -[missing collective] -The sync collective component in Open MPI was activated on a -communicator where it did not find an underlying collective operation -defined. This usually means that the sync collective module's -priority was not set high enough. Please try increasing sync's -priority. - - Local host: %s - Sync coll module priority: %d - First discovered missing collective: %s diff --git a/ompi/mca/coll/tuned/coll_tuned_module.c b/ompi/mca/coll/tuned/coll_tuned_module.c index f1be43440cc..f83b3ecd9ea 100644 --- a/ompi/mca/coll/tuned/coll_tuned_module.c +++ b/ompi/mca/coll/tuned/coll_tuned_module.c @@ -13,6 +13,7 @@ * Copyright (c) 2016 Intel, Inc. All rights reserved. 
* Copyright (c) 2018 Research Organization for Information Science * and Technology (RIST). All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -36,6 +37,8 @@ static int tuned_module_enable(mca_coll_base_module_t *module, struct ompi_communicator_t *comm); +static int tuned_module_disable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm); /* * Initial query function that is invoked during MPI_INIT, allowing * this component to disqualify itself if it doesn't support the @@ -89,6 +92,7 @@ ompi_coll_tuned_comm_query(struct ompi_communicator_t *comm, int *priority) * but this would probably add an extra if and funct call to the path */ tuned_module->super.coll_module_enable = tuned_module_enable; + tuned_module->super.coll_module_disable = tuned_module_disable; /* By default stick with the fixed version of the tuned collectives. Later on, * when the module get enabled, set the correct version based on the availability @@ -99,18 +103,13 @@ ompi_coll_tuned_comm_query(struct ompi_communicator_t *comm, int *priority) tuned_module->super.coll_allreduce = ompi_coll_tuned_allreduce_intra_dec_fixed; tuned_module->super.coll_alltoall = ompi_coll_tuned_alltoall_intra_dec_fixed; tuned_module->super.coll_alltoallv = ompi_coll_tuned_alltoallv_intra_dec_fixed; - tuned_module->super.coll_alltoallw = NULL; tuned_module->super.coll_barrier = ompi_coll_tuned_barrier_intra_dec_fixed; tuned_module->super.coll_bcast = ompi_coll_tuned_bcast_intra_dec_fixed; - tuned_module->super.coll_exscan = NULL; tuned_module->super.coll_gather = ompi_coll_tuned_gather_intra_dec_fixed; - tuned_module->super.coll_gatherv = NULL; tuned_module->super.coll_reduce = ompi_coll_tuned_reduce_intra_dec_fixed; tuned_module->super.coll_reduce_scatter = ompi_coll_tuned_reduce_scatter_intra_dec_fixed; tuned_module->super.coll_reduce_scatter_block = ompi_coll_tuned_reduce_scatter_block_intra_dec_fixed; - 
tuned_module->super.coll_scan = NULL; tuned_module->super.coll_scatter = ompi_coll_tuned_scatter_intra_dec_fixed; - tuned_module->super.coll_scatterv = NULL; return &(tuned_module->super); } @@ -148,8 +147,26 @@ ompi_coll_tuned_forced_getvalues( enum COLLTYPE type, return (MPI_SUCCESS); } -#define COLL_TUNED_EXECUTE_IF_DYNAMIC(TMOD, TYPE, EXECUTE) \ +#define TUNED_INSTALL_COLL_API(__comm, __module, __api) \ + do \ + { \ + if (__module->super.coll_##__api) \ + { \ + MCA_COLL_INSTALL_API(__comm, __api, __module->super.coll_##__api, &__module->super, "tuned"); \ + } \ + } while (0) + +#define TUNED_UNINSTALL_COLL_API(__comm, __module, __api) \ + do \ { \ + if (__comm->c_coll->coll_##__api##_module == &__module->super) \ + { \ + MCA_COLL_INSTALL_API(__comm, __api, NULL, NULL, "tuned"); \ + } \ + } while (0) + +#define COLL_TUNED_EXECUTE_IF_DYNAMIC(TMOD, TYPE, EXECUTE) \ + do { \ int need_dynamic_decision = 0; \ ompi_coll_tuned_forced_getvalues( (TYPE), &((TMOD)->user_forced[(TYPE)]) ); \ (TMOD)->com_rules[(TYPE)] = NULL; \ @@ -168,7 +185,7 @@ ompi_coll_tuned_forced_getvalues( enum COLLTYPE type, OPAL_OUTPUT((ompi_coll_tuned_stream,"coll:tuned: enable dynamic selection for "#TYPE)); \ EXECUTE; \ } \ - } + } while(0) /* * Init module on the communicator @@ -249,6 +266,23 @@ tuned_module_enable( mca_coll_base_module_t *module, COLL_TUNED_EXECUTE_IF_DYNAMIC(tuned_module, SCATTERV, tuned_module->super.coll_scatterv = NULL); } + TUNED_INSTALL_COLL_API(comm, tuned_module, allgather); + TUNED_INSTALL_COLL_API(comm, tuned_module, allgatherv); + TUNED_INSTALL_COLL_API(comm, tuned_module, allreduce); + TUNED_INSTALL_COLL_API(comm, tuned_module, alltoall); + TUNED_INSTALL_COLL_API(comm, tuned_module, alltoallv); + TUNED_INSTALL_COLL_API(comm, tuned_module, alltoallw); + TUNED_INSTALL_COLL_API(comm, tuned_module, barrier); + TUNED_INSTALL_COLL_API(comm, tuned_module, bcast); + TUNED_INSTALL_COLL_API(comm, tuned_module, exscan); + TUNED_INSTALL_COLL_API(comm, tuned_module, gather); 
+ TUNED_INSTALL_COLL_API(comm, tuned_module, gatherv); + TUNED_INSTALL_COLL_API(comm, tuned_module, reduce); + TUNED_INSTALL_COLL_API(comm, tuned_module, reduce_scatter); + TUNED_INSTALL_COLL_API(comm, tuned_module, reduce_scatter_block); + TUNED_INSTALL_COLL_API(comm, tuned_module, scan); + TUNED_INSTALL_COLL_API(comm, tuned_module, scatter); + TUNED_INSTALL_COLL_API(comm, tuned_module, scatterv); /* general n fan out tree */ data->cached_ntree = NULL; @@ -273,3 +307,30 @@ tuned_module_enable( mca_coll_base_module_t *module, OPAL_OUTPUT((ompi_coll_tuned_stream,"coll:tuned:module_init Tuned is in use")); return OMPI_SUCCESS; } + +static int +tuned_module_disable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm) +{ + mca_coll_tuned_module_t *tuned_module = (mca_coll_tuned_module_t *) module; + + TUNED_UNINSTALL_COLL_API(comm, tuned_module, allgather); + TUNED_UNINSTALL_COLL_API(comm, tuned_module, allgatherv); + TUNED_UNINSTALL_COLL_API(comm, tuned_module, allreduce); + TUNED_UNINSTALL_COLL_API(comm, tuned_module, alltoall); + TUNED_UNINSTALL_COLL_API(comm, tuned_module, alltoallv); + TUNED_UNINSTALL_COLL_API(comm, tuned_module, alltoallw); + TUNED_UNINSTALL_COLL_API(comm, tuned_module, barrier); + TUNED_UNINSTALL_COLL_API(comm, tuned_module, bcast); + TUNED_UNINSTALL_COLL_API(comm, tuned_module, exscan); + TUNED_UNINSTALL_COLL_API(comm, tuned_module, gather); + TUNED_UNINSTALL_COLL_API(comm, tuned_module, gatherv); + TUNED_UNINSTALL_COLL_API(comm, tuned_module, reduce); + TUNED_UNINSTALL_COLL_API(comm, tuned_module, reduce_scatter); + TUNED_UNINSTALL_COLL_API(comm, tuned_module, reduce_scatter_block); + TUNED_UNINSTALL_COLL_API(comm, tuned_module, scan); + TUNED_UNINSTALL_COLL_API(comm, tuned_module, scatter); + TUNED_UNINSTALL_COLL_API(comm, tuned_module, scatterv); + + return OMPI_SUCCESS; +} diff --git a/ompi/mca/coll/ucc/coll_ucc_module.c b/ompi/mca/coll/ucc/coll_ucc_module.c index 32861d315de..35bceac3f7e 100644 --- 
a/ompi/mca/coll/ucc/coll_ucc_module.c +++ b/ompi/mca/coll/ucc/coll_ucc_module.c @@ -1,8 +1,8 @@ /** - * Copyright (c) 2021 Mellanox Technologies. All rights reserved. + * Copyright (c) 2021 Mellanox Technologies. All rights reserved. * Copyright (c) 2022 Amazon.com, Inc. or its affiliates. * All Rights reserved. - * Copyright (c) 2022 NVIDIA Corporation. All rights reserved. + * Copyright (c) 2022-2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -16,8 +16,6 @@ #include "ompi/mca/coll/base/coll_tags.h" #include "ompi/mca/pml/pml.h" -#define OBJ_RELEASE_IF_NOT_NULL( obj ) if( NULL != (obj) ) OBJ_RELEASE( obj ); - static int ucc_comm_attr_keyval; /* * Initial query function that is invoked during MPI_INIT, allowing @@ -107,80 +105,9 @@ static void mca_coll_ucc_module_destruct(mca_coll_ucc_module_t *ucc_module) UCC_ERROR("ucc ompi_attr_free_keyval failed"); } } - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_allreduce_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_iallreduce_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_barrier_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_ibarrier_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_bcast_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_ibcast_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_alltoall_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_ialltoall_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_alltoallv_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_ialltoallv_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_allgather_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_iallgather_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_allgatherv_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_iallgatherv_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_reduce_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_ireduce_module); - 
OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_gather_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_igather_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_gatherv_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_igatherv_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_reduce_scatter_block_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_ireduce_scatter_block_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_reduce_scatter_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_ireduce_scatter_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_scatterv_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_iscatterv_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_scatter_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_iscatter_module); mca_coll_ucc_module_clear(ucc_module); } -#define SAVE_PREV_COLL_API(__api) do { \ - ucc_module->previous_ ## __api = comm->c_coll->coll_ ## __api; \ - ucc_module->previous_ ## __api ## _module = comm->c_coll->coll_ ## __api ## _module; \ - if (!comm->c_coll->coll_ ## __api || !comm->c_coll->coll_ ## __api ## _module) { \ - return OMPI_ERROR; \ - } \ - OBJ_RETAIN(ucc_module->previous_ ## __api ## _module); \ - } while(0) - -static int mca_coll_ucc_save_coll_handlers(mca_coll_ucc_module_t *ucc_module) -{ - ompi_communicator_t *comm = ucc_module->comm; - SAVE_PREV_COLL_API(allreduce); - SAVE_PREV_COLL_API(iallreduce); - SAVE_PREV_COLL_API(barrier); - SAVE_PREV_COLL_API(ibarrier); - SAVE_PREV_COLL_API(bcast); - SAVE_PREV_COLL_API(ibcast); - SAVE_PREV_COLL_API(alltoall); - SAVE_PREV_COLL_API(ialltoall); - SAVE_PREV_COLL_API(alltoallv); - SAVE_PREV_COLL_API(ialltoallv); - SAVE_PREV_COLL_API(allgather); - SAVE_PREV_COLL_API(iallgather); - SAVE_PREV_COLL_API(allgatherv); - SAVE_PREV_COLL_API(iallgatherv); - SAVE_PREV_COLL_API(reduce); - SAVE_PREV_COLL_API(ireduce); - SAVE_PREV_COLL_API(gather); - SAVE_PREV_COLL_API(igather); - SAVE_PREV_COLL_API(gatherv); - 
SAVE_PREV_COLL_API(igatherv); - SAVE_PREV_COLL_API(reduce_scatter_block); - SAVE_PREV_COLL_API(ireduce_scatter_block); - SAVE_PREV_COLL_API(reduce_scatter); - SAVE_PREV_COLL_API(ireduce_scatter); - SAVE_PREV_COLL_API(scatterv); - SAVE_PREV_COLL_API(iscatterv); - SAVE_PREV_COLL_API(scatter); - SAVE_PREV_COLL_API(iscatter); - return OMPI_SUCCESS; -} - /* ** Communicator free callback */ @@ -440,6 +367,50 @@ static inline ucc_ep_map_t get_rank_map(struct ompi_communicator_t *comm) return map; } + +#define UCC_INSTALL_COLL_API(__comm, __ucc_module, __COLL, __api) \ + do \ + { \ + if ((mca_coll_ucc_component.ucc_lib_attr.coll_types & UCC_COLL_TYPE_##__COLL)) \ + { \ + if (mca_coll_ucc_component.cts_requested & UCC_COLL_TYPE_##__COLL) \ + { \ + MCA_COLL_SAVE_API(__comm, __api, (__ucc_module)->previous_##__api, (__ucc_module)->previous_##__api##_module, "ucc"); \ + MCA_COLL_INSTALL_API(__comm, __api, mca_coll_ucc_##__api, &__ucc_module->super, "ucc"); \ + (__ucc_module)->super.coll_##__api = mca_coll_ucc_##__api; \ + } \ + if (mca_coll_ucc_component.nb_cts_requested & UCC_COLL_TYPE_##__COLL) \ + { \ + MCA_COLL_SAVE_API(__comm, i##__api, (__ucc_module)->previous_i##__api, (__ucc_module)->previous_i##__api##_module, "ucc"); \ + MCA_COLL_INSTALL_API(__comm, i##__api, mca_coll_ucc_i##__api, &__ucc_module->super, "ucc"); \ + (__ucc_module)->super.coll_i##__api = mca_coll_ucc_i##__api; \ + } \ + } \ + } while (0) + +static int mca_coll_ucc_replace_coll_handlers(mca_coll_ucc_module_t *ucc_module) +{ + ompi_communicator_t *comm = ucc_module->comm; + + UCC_INSTALL_COLL_API(comm, ucc_module, ALLREDUCE, allreduce); + UCC_INSTALL_COLL_API(comm, ucc_module, BARRIER, barrier); + UCC_INSTALL_COLL_API(comm, ucc_module, BCAST, bcast); + UCC_INSTALL_COLL_API(comm, ucc_module, ALLTOALL, alltoall); + UCC_INSTALL_COLL_API(comm, ucc_module, ALLTOALLV, alltoallv); + UCC_INSTALL_COLL_API(comm, ucc_module, ALLGATHER, allgather); + UCC_INSTALL_COLL_API(comm, ucc_module, ALLGATHERV, allgatherv); + 
UCC_INSTALL_COLL_API(comm, ucc_module, REDUCE, reduce); + + UCC_INSTALL_COLL_API(comm, ucc_module, GATHER, gather); + UCC_INSTALL_COLL_API(comm, ucc_module, GATHERV, gatherv); + UCC_INSTALL_COLL_API(comm, ucc_module, REDUCE_SCATTER_BLOCK, reduce_scatter_block); + UCC_INSTALL_COLL_API(comm, ucc_module, REDUCE_SCATTER, reduce_scatter); + UCC_INSTALL_COLL_API(comm, ucc_module, SCATTER, scatter); + UCC_INSTALL_COLL_API(comm, ucc_module, SCATTERV, scatterv); + + return OMPI_SUCCESS; +} + /* * Initialize module on the communicator */ @@ -470,11 +441,6 @@ static int mca_coll_ucc_module_enable(mca_coll_base_module_t *module, (void*)comm, (long long unsigned)team_params.id, ompi_comm_size(comm)); - if (OMPI_SUCCESS != mca_coll_ucc_save_coll_handlers(ucc_module)){ - UCC_ERROR("mca_coll_ucc_save_coll_handlers failed"); - goto err; - } - if (UCC_OK != ucc_team_create_post(&cm->ucc_context, 1, &team_params, &ucc_module->ucc_team)) { UCC_ERROR("ucc_team_create_post failed"); @@ -489,12 +455,18 @@ static int mca_coll_ucc_module_enable(mca_coll_base_module_t *module, goto err; } + if (OMPI_SUCCESS != mca_coll_ucc_replace_coll_handlers(ucc_module)) { + UCC_ERROR("mca_coll_ucc_replace_coll_handlers failed"); + goto err; + } + rc = ompi_attr_set_c(COMM_ATTR, comm, &comm->c_keyhash, ucc_comm_attr_keyval, (void *)module, false); if (OMPI_SUCCESS != rc) { UCC_ERROR("ucc ompi_attr_set_c failed"); goto err; } + return OMPI_SUCCESS; err: @@ -504,22 +476,53 @@ static int mca_coll_ucc_module_enable(mca_coll_base_module_t *module, return OMPI_ERROR; } +#define UCC_UNINSTALL_COLL_API(__comm, __ucc_module, __api) \ + do \ + { \ + if (&(__ucc_module)->super == (__comm)->c_coll->coll_##__api##_module) \ + { \ + MCA_COLL_INSTALL_API(__comm, __api, (__ucc_module)->previous_##__api, (__ucc_module)->previous_##__api##_module, "ucc"); \ + (__ucc_module)->previous_##__api = NULL; \ + (__ucc_module)->previous_##__api##_module = NULL; \ + } \ + } while (0) + +/** + * The disable will be called once per 
collective module, in the reverse order + * in which enable has been called. This reverse order allows the module to properly + * unregister the collective function pointers they provide for the communicator. + */ +static int +mca_coll_ucc_module_disable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm) +{ + mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t*)module; + UCC_UNINSTALL_COLL_API(comm, ucc_module, allreduce); + UCC_UNINSTALL_COLL_API(comm, ucc_module, iallreduce); + UCC_UNINSTALL_COLL_API(comm, ucc_module, barrier); + UCC_UNINSTALL_COLL_API(comm, ucc_module, ibarrier); + UCC_UNINSTALL_COLL_API(comm, ucc_module, bcast); + UCC_UNINSTALL_COLL_API(comm, ucc_module, ibcast); + UCC_UNINSTALL_COLL_API(comm, ucc_module, alltoall); + UCC_UNINSTALL_COLL_API(comm, ucc_module, ialltoall); + UCC_UNINSTALL_COLL_API(comm, ucc_module, alltoallv); + UCC_UNINSTALL_COLL_API(comm, ucc_module, ialltoallv); + UCC_UNINSTALL_COLL_API(comm, ucc_module, allgather); + UCC_UNINSTALL_COLL_API(comm, ucc_module, iallgather); + UCC_UNINSTALL_COLL_API(comm, ucc_module, allgatherv); + UCC_UNINSTALL_COLL_API(comm, ucc_module, iallgatherv); + UCC_UNINSTALL_COLL_API(comm, ucc_module, reduce); + UCC_UNINSTALL_COLL_API(comm, ucc_module, ireduce); + UCC_UNINSTALL_COLL_API(comm, ucc_module, gather); + UCC_UNINSTALL_COLL_API(comm, ucc_module, gatherv); + UCC_UNINSTALL_COLL_API(comm, ucc_module, reduce_scatter_block); + UCC_UNINSTALL_COLL_API(comm, ucc_module, reduce_scatter); + UCC_UNINSTALL_COLL_API(comm, ucc_module, scatter); + UCC_UNINSTALL_COLL_API(comm, ucc_module, scatterv); + + return OMPI_SUCCESS; +} -#define SET_COLL_PTR(_module, _COLL, _coll) do { \ - _module->super.coll_ ## _coll = NULL; \ - _module->super.coll_i ## _coll = NULL; \ - if ((mca_coll_ucc_component.ucc_lib_attr.coll_types & \ - UCC_COLL_TYPE_ ## _COLL)) { \ - if (mca_coll_ucc_component.cts_requested & \ - UCC_COLL_TYPE_ ## _COLL) { \ - _module->super.coll_ ## _coll = mca_coll_ucc_ ## _coll; \ - 
} \ - if (mca_coll_ucc_component.nb_cts_requested & \ - UCC_COLL_TYPE_ ## _COLL) { \ - _module->super.coll_i ## _coll = mca_coll_ucc_i ## _coll; \ - } \ - } \ - } while(0) /* * Invoked when there's a new communicator that has been created. @@ -554,23 +557,11 @@ mca_coll_ucc_comm_query(struct ompi_communicator_t *comm, int *priority) cm->ucc_enable = 0; return NULL; } - ucc_module->comm = comm; - ucc_module->super.coll_module_enable = mca_coll_ucc_module_enable; - *priority = cm->ucc_priority; - SET_COLL_PTR(ucc_module, BARRIER, barrier); - SET_COLL_PTR(ucc_module, BCAST, bcast); - SET_COLL_PTR(ucc_module, ALLREDUCE, allreduce); - SET_COLL_PTR(ucc_module, ALLTOALL, alltoall); - SET_COLL_PTR(ucc_module, ALLTOALLV, alltoallv); - SET_COLL_PTR(ucc_module, REDUCE, reduce); - SET_COLL_PTR(ucc_module, ALLGATHER, allgather); - SET_COLL_PTR(ucc_module, ALLGATHERV, allgatherv); - SET_COLL_PTR(ucc_module, GATHER, gather); - SET_COLL_PTR(ucc_module, GATHERV, gatherv); - SET_COLL_PTR(ucc_module, REDUCE_SCATTER, reduce_scatter_block); - SET_COLL_PTR(ucc_module, REDUCE_SCATTERV, reduce_scatter); - SET_COLL_PTR(ucc_module, SCATTERV, scatterv); - SET_COLL_PTR(ucc_module, SCATTER, scatter); + ucc_module->comm = comm; + ucc_module->super.coll_module_enable = mca_coll_ucc_module_enable; + ucc_module->super.coll_module_disable = mca_coll_ucc_module_disable; + *priority = cm->ucc_priority; + return &ucc_module->super; } diff --git a/ompi/request/request_dbg.h b/ompi/request/request_dbg.h index e6aa3757d06..5929374ade4 100644 --- a/ompi/request/request_dbg.h +++ b/ompi/request/request_dbg.h @@ -1,6 +1,7 @@ /* -*- Mode: C; c-basic-offset:4 ; -*- */ /* * Copyright (c) 2009 Sun Microsystems, Inc. All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. 
* $COPYRIGHT$ * * Additional copyrights may follow @@ -17,7 +18,7 @@ */ /** - * Enum inidicating the type of the request + * Enum indicating the type of the request */ typedef enum { OMPI_REQUEST_PML, /**< MPI point-to-point request */