diff --git a/.github/workflows/compile-ze.yaml b/.github/workflows/compile-ze.yaml new file mode 100644 index 00000000000..dcfef7094c6 --- /dev/null +++ b/.github/workflows/compile-ze.yaml @@ -0,0 +1,31 @@ +name: OneAPI ZE + +on: [pull_request] + +jobs: + compile-ze: + runs-on: ubuntu-22.04 + steps: + - name: Install dependencies + run: | + sudo apt update + sudo apt install -y --no-install-recommends wget lsb-core software-properties-common gpg curl cmake git + - name: Build OneAPI ZE + run: | + git clone https://github.com/oneapi-src/level-zero.git + cd level-zero + mkdir build + cd build + cmake ../ -DCMAKE_INSTALL_PREFIX=/opt/ze + sudo make -j install + - uses: actions/checkout@v3 + with: + submodules: recursive + - name: Build Open MPI + run: | + ./autogen.pl + # + # we have to disable romio as its old ze stuff doesn't compile with supported ZE API + # + ./configure --prefix=${PWD}/install --disable-mpi-fortran --disable-io-romio --disable-oshmem --with-ze + make -j diff --git a/.gitmodules b/.gitmodules index d50b377cc73..6914849f45a 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +1,6 @@ [submodule "prrte"] path = 3rd-party/prrte - url = ../../openpmix/prrte + url = ../../open-mpi/prrte branch = master [submodule "openpmix"] path = 3rd-party/openpmix diff --git a/3rd-party/openpmix b/3rd-party/openpmix index 6f81bfd163f..213956cf00f 160000 --- a/3rd-party/openpmix +++ b/3rd-party/openpmix @@ -1 +1 @@ -Subproject commit 6f81bfd163f3275d2b0630974968c82759dd4439 +Subproject commit 213956cf00ff164230de06b9887a9412bf1e1dad diff --git a/3rd-party/prrte b/3rd-party/prrte index 4f27008906d..1d867e84981 160000 --- a/3rd-party/prrte +++ b/3rd-party/prrte @@ -1 +1 @@ -Subproject commit 4f27008906d96845e22df6502d6a9a29d98dec83 +Subproject commit 1d867e84981077bffda9ad9d44ff415a3f6d91c4 diff --git a/config/ompi_fortran_check_asynchronous.m4 b/config/ompi_fortran_check_asynchronous.m4 index 0cc3c84bfe5..62c53159c50 100644 --- a/config/ompi_fortran_check_asynchronous.m4 +++ b/config/ompi_fortran_check_asynchronous.m4 @@ -11,6 +11,8 @@ dnl University of Stuttgart. All rights reserved. dnl Copyright (c) 2004-2005 The Regents of the University of California. dnl All rights reserved. dnl Copyright (c) 2010-2014 Cisco Systems, Inc. All rights reserved. +dnl Copyright (c) 2024 Research Organization for Information Science +dnl and Technology (RIST). All rights reserved. dnl $COPYRIGHT$ dnl dnl Additional copyrights may follow @@ -35,6 +37,10 @@ SUBROUTINE binky(buf) REAL, DIMENSION(*), ASYNCHRONOUS :: buf END SUBROUTINE END INTERFACE +CONTAINS +SUBROUTINE wookie(buf) + REAL, DIMENSION(*), ASYNCHRONOUS :: buf +END SUBROUTINE END MODULE asynch_mod]])], [AS_VAR_SET(asynchronous_var, yes)], [AS_VAR_SET(asynchronous_var, no)]) diff --git a/config/ompi_fortran_check_ignore_tkr.m4 b/config/ompi_fortran_check_ignore_tkr.m4 index 34d42f90847..3686bca82e8 100644 --- a/config/ompi_fortran_check_ignore_tkr.m4 +++ b/config/ompi_fortran_check_ignore_tkr.m4 @@ -14,6 +14,8 @@ dnl Copyright (c) 2007 Los Alamos National Security, LLC. All rights dnl reserved. dnl Copyright (c) 2007 Sun Microsystems, Inc. All rights reserved. dnl Copyright (c) 2009-2015 Cisco Systems, Inc. All rights reserved. +dnl Copyright (c) 2024 Research Organization for Information Science +dnl and Technology (RIST). All rights reserved. 
dnl $COPYRIGHT$ dnl dnl Additional copyrights may follow @@ -82,6 +84,12 @@ AC_DEFUN([_OMPI_FORTRAN_CHECK_IGNORE_TKR], [ [!GCC\$ ATTRIBUTES NO_ARG_CHECK ::], [type(*), dimension(*)], [!GCC\$ ATTRIBUTES NO_ARG_CHECK], [internal_ignore_tkr_happy=1], [internal_ignore_tkr_happy=0])]) + # LLVM compilers + AS_IF([test $internal_ignore_tkr_happy -eq 0], + [OMPI_FORTRAN_CHECK_IGNORE_TKR_SUB( + [!DIR\$ IGNORE_TKR], [type(*)], + [!DIR\$ IGNORE_TKR], + [internal_ignore_tkr_happy=1], [internal_ignore_tkr_happy=0])]) # Intel compilers AS_IF([test $internal_ignore_tkr_happy -eq 0], [OMPI_FORTRAN_CHECK_IGNORE_TKR_SUB( @@ -133,6 +141,7 @@ AC_DEFUN([OMPI_FORTRAN_CHECK_IGNORE_TKR_SUB], [ AC_MSG_CHECKING([for Fortran compiler support of $3]) AC_COMPILE_IFELSE(AC_LANG_PROGRAM([],[[! ! Autoconf puts "program main" at the top + implicit none interface subroutine force_assumed_shape(a, count) @@ -157,6 +166,7 @@ AC_DEFUN([OMPI_FORTRAN_CHECK_IGNORE_TKR_SUB], [ complex, pointer, dimension(:,:) :: ptr target :: buffer3 integer :: buffer4 + integer :: a ptr => buffer3 ! Set some known values (somewhat irrelevant for this test, but just be @@ -189,8 +199,23 @@ AC_DEFUN([OMPI_FORTRAN_CHECK_IGNORE_TKR_SUB], [ call foo(a, count) end subroutine force_assumed_shape + module check_ignore_tkr + interface + subroutine foobar(buffer, count) + $1 buffer + $2, intent(in) :: buffer + integer, intent(in) :: count + end subroutine foobar + end interface + end module + + subroutine bar(var) + use check_ignore_tkr + implicit none + real, intent(inout) :: var(:, :, :) + + call foobar(var(1,1,1), 1) ! Autoconf puts "end" after the last line - subroutine bogus ]]), [msg=yes ompi_fortran_ignore_tkr_predecl="$1" diff --git a/config/opal_check_cuda.m4 b/config/opal_check_cuda.m4 index 87f5dc01ef1..67d0cbeede1 100644 --- a/config/opal_check_cuda.m4 +++ b/config/opal_check_cuda.m4 @@ -59,8 +59,8 @@ AC_ARG_WITH([cuda-libdir], [Search for CUDA libraries in DIR])], [], [AS_IF([test -d "$with_cuda"], - [with_cuda_libdir=$(dirname $(find $with_cuda -name libcuda.so 2> /dev/null) 2> /dev/null)], - [with_cuda_libdir=$(dirname $(find /usr/local/cuda -name libcuda.so 2> /dev/null) 2> /dev/null)]) + [with_cuda_libdir=$(dirname $(find -H $with_cuda -name libcuda.so 2> /dev/null) 2> /dev/null)], + [with_cuda_libdir=$(dirname $(find -H /usr/local/cuda -name libcuda.so 2> /dev/null) 2> /dev/null)]) ]) # Note that CUDA support is off by default. To turn it on, the user has to diff --git a/docs/getting-help.rst b/docs/getting-help.rst index 535a020963e..ca9823f8a57 100644 --- a/docs/getting-help.rst +++ b/docs/getting-help.rst @@ -120,7 +120,7 @@ information (adjust as necessary for your specific environment): # Fill in the options you want to pass to configure here options="" ./configure $options 2>&1 | tee $dir/config.out - tar -cf - `find . -name config.log` | tar -x -C $dir - + tar -cf - `find . -name config.log` | tar -x -C $dir # Build and install Open MPI make V=1 all 2>&1 | tee $dir/make.out diff --git a/docs/man-openmpi/man3/MPI_Send.3.rst b/docs/man-openmpi/man3/MPI_Send.3.rst index 69af46f6048..32219376939 100644 --- a/docs/man-openmpi/man3/MPI_Send.3.rst +++ b/docs/man-openmpi/man3/MPI_Send.3.rst @@ -53,7 +53,7 @@ Fortran 2008 Syntax INPUT PARAMETERS ---------------- * ``buf``: Initial address of send buffer (choice). -* ``count``: Number of elements send (nonnegative integer). +* ``count``: Number of elements in send buffer (nonnegative integer). * ``datatype``: Datatype of each send buffer element (handle). 
* ``dest``: Rank of destination (integer). * ``tag``: Message tag (integer). diff --git a/ompi/communicator/comm.c b/ompi/communicator/comm.c index d3b8e4ae243..545356d958d 100644 --- a/ompi/communicator/comm.c +++ b/ompi/communicator/comm.c @@ -707,11 +707,6 @@ int ompi_comm_split_with_info( ompi_communicator_t* comm, int color, int key, /* Activate the communicator and init coll-component */ rc = ompi_comm_activate (&newcomp, comm, NULL, NULL, NULL, false, mode); - /* MPI-4 §7.4.4 requires us to remove all unknown keys from the info object */ - if (NULL != newcomp->super.s_info) { - opal_info_remove_unreferenced(newcomp->super.s_info); - } - exit: free ( results ); free ( sorted ); @@ -1028,9 +1023,6 @@ static int ompi_comm_split_type_core(ompi_communicator_t *comm, goto exit; } - /* MPI-4 §7.4.4 requires us to remove all unknown keys from the info object */ - opal_info_remove_unreferenced(newcomp->super.s_info); - /* TODO: there probably is better way to handle this case without throwing away the * intermediate communicator. */ rc = ompi_comm_split (newcomp, local_split_type, key, newcomm, false); @@ -1363,9 +1355,6 @@ int ompi_comm_dup_with_info ( ompi_communicator_t * comm, opal_info_t *info, omp return rc; } - /* MPI-4 §7.4.4 requires us to remove all unknown keys from the info object */ - opal_info_remove_unreferenced(newcomp->super.s_info); - *newcomm = newcomp; return MPI_SUCCESS; } @@ -1522,8 +1511,6 @@ static int ompi_comm_idup_with_info_finish (ompi_comm_request_t *request) { ompi_comm_idup_with_info_context_t *context = (ompi_comm_idup_with_info_context_t *) request->context; - /* MPI-4 §7.4.4 requires us to remove all unknown keys from the info object */ - opal_info_remove_unreferenced(context->newcomp->super.s_info); /* done */ return MPI_SUCCESS; diff --git a/ompi/file/file.c b/ompi/file/file.c index f8c009764a8..9026fbea751 100644 --- a/ompi/file/file.c +++ b/ompi/file/file.c @@ -138,9 +138,6 @@ int ompi_file_open(struct ompi_communicator_t *comm, const char *filename, return ret; } - /* MPI-4 §14.2.8 requires us to remove all unknown keys from the info object */ - opal_info_remove_unreferenced(file->super.s_info); - /* All done */ *fh = file; diff --git a/ompi/info/info.c b/ompi/info/info.c index e33fc9a6684..577910da840 100644 --- a/ompi/info/info.c +++ b/ompi/info/info.c @@ -243,7 +243,7 @@ int ompi_mpiinfo_init_env(int argc, char *argv[], ompi_info_t *info) // related calls: int ompi_info_dup (ompi_info_t *info, ompi_info_t **newinfo) { - return opal_info_dup (&(info->super), (opal_info_t **)newinfo); + return opal_info_dup_public (&(info->super), (opal_info_t **)newinfo); } int ompi_info_set (ompi_info_t *info, const char *key, const char *value) { return opal_info_set (&(info->super), key, value); diff --git a/ompi/mca/coll/accelerator/Makefile.am b/ompi/mca/coll/accelerator/Makefile.am index eaf81137602..e3621c1d05a 100644 --- a/ompi/mca/coll/accelerator/Makefile.am +++ b/ompi/mca/coll/accelerator/Makefile.am @@ -2,7 +2,7 @@ # Copyright (c) 2014 The University of Tennessee and The University # of Tennessee Research Foundation. All rights # reserved. -# Copyright (c) 2014 NVIDIA Corporation. All rights reserved. +# Copyright (c) 2014-2024 NVIDIA Corporation. All rights reserved. # Copyright (c) 2017 IBM Corporation. All rights reserved. 
# $COPYRIGHT$ # @@ -10,7 +10,6 @@ # # $HEADER$ # -dist_ompidata_DATA = help-mpi-coll-accelerator.txt sources = coll_accelerator_module.c coll_accelerator_reduce.c coll_accelerator_allreduce.c \ coll_accelerator_reduce_scatter_block.c coll_accelerator_component.c \ diff --git a/ompi/mca/coll/accelerator/coll_accelerator.h b/ompi/mca/coll/accelerator/coll_accelerator.h index c840a3c2d27..5e8eb64794f 100644 --- a/ompi/mca/coll/accelerator/coll_accelerator.h +++ b/ompi/mca/coll/accelerator/coll_accelerator.h @@ -2,7 +2,7 @@ * Copyright (c) 2014 The University of Tennessee and The University * of Tennessee Research Foundation. All rights * reserved. - * Copyright (c) 2014-2015 NVIDIA Corporation. All rights reserved. + * Copyright (c) 2014-2024 NVIDIA Corporation. All rights reserved. * Copyright (c) 2024 Triad National Security, LLC. All rights reserved. * $COPYRIGHT$ * @@ -38,9 +38,6 @@ mca_coll_base_module_t *mca_coll_accelerator_comm_query(struct ompi_communicator_t *comm, int *priority); -int mca_coll_accelerator_module_enable(mca_coll_base_module_t *module, - struct ompi_communicator_t *comm); - int mca_coll_accelerator_allreduce(const void *sbuf, void *rbuf, int count, struct ompi_datatype_t *dtype, diff --git a/ompi/mca/coll/accelerator/coll_accelerator_module.c b/ompi/mca/coll/accelerator/coll_accelerator_module.c index 505d31c1c07..4fe1603a8aa 100644 --- a/ompi/mca/coll/accelerator/coll_accelerator_module.c +++ b/ompi/mca/coll/accelerator/coll_accelerator_module.c @@ -2,7 +2,7 @@ * Copyright (c) 2014-2017 The University of Tennessee and The University * of Tennessee Research Foundation. All rights * reserved. - * Copyright (c) 2014 NVIDIA Corporation. All rights reserved. + * Copyright (c) 2014-2024 NVIDIA Corporation. All rights reserved. * Copyright (c) 2019 Research Organization for Information Science * and Technology (RIST). All rights reserved. * Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. @@ -32,30 +32,21 @@ #include "ompi/mca/coll/base/base.h" #include "coll_accelerator.h" +static int +mca_coll_accelerator_module_enable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm); +static int +mca_coll_accelerator_module_disable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm); static void mca_coll_accelerator_module_construct(mca_coll_accelerator_module_t *module) { memset(&(module->c_coll), 0, sizeof(module->c_coll)); } -static void mca_coll_accelerator_module_destruct(mca_coll_accelerator_module_t *module) -{ - OBJ_RELEASE(module->c_coll.coll_allreduce_module); - OBJ_RELEASE(module->c_coll.coll_reduce_module); - OBJ_RELEASE(module->c_coll.coll_reduce_scatter_block_module); - OBJ_RELEASE(module->c_coll.coll_scatter_module); - /* If the exscan module is not NULL, then this was an - intracommunicator, and therefore scan will have a module as - well. 
*/ - if (NULL != module->c_coll.coll_exscan_module) { - OBJ_RELEASE(module->c_coll.coll_exscan_module); - OBJ_RELEASE(module->c_coll.coll_scan_module); - } -} - OBJ_CLASS_INSTANCE(mca_coll_accelerator_module_t, mca_coll_base_module_t, mca_coll_accelerator_module_construct, - mca_coll_accelerator_module_destruct); + NULL); /* @@ -99,66 +90,82 @@ mca_coll_accelerator_comm_query(struct ompi_communicator_t *comm, /* Choose whether to use [intra|inter] */ accelerator_module->super.coll_module_enable = mca_coll_accelerator_module_enable; + accelerator_module->super.coll_module_disable = mca_coll_accelerator_module_disable; - accelerator_module->super.coll_allgather = NULL; - accelerator_module->super.coll_allgatherv = NULL; accelerator_module->super.coll_allreduce = mca_coll_accelerator_allreduce; - accelerator_module->super.coll_alltoall = NULL; - accelerator_module->super.coll_alltoallv = NULL; - accelerator_module->super.coll_alltoallw = NULL; - accelerator_module->super.coll_barrier = NULL; - accelerator_module->super.coll_bcast = NULL; - accelerator_module->super.coll_exscan = mca_coll_accelerator_exscan; - accelerator_module->super.coll_gather = NULL; - accelerator_module->super.coll_gatherv = NULL; accelerator_module->super.coll_reduce = mca_coll_accelerator_reduce; - accelerator_module->super.coll_reduce_scatter = NULL; accelerator_module->super.coll_reduce_scatter_block = mca_coll_accelerator_reduce_scatter_block; - accelerator_module->super.coll_scan = mca_coll_accelerator_scan; - accelerator_module->super.coll_scatter = NULL; - accelerator_module->super.coll_scatterv = NULL; + if (!OMPI_COMM_IS_INTER(comm)) { + accelerator_module->super.coll_scan = mca_coll_accelerator_scan; + accelerator_module->super.coll_exscan = mca_coll_accelerator_exscan; + } return &(accelerator_module->super); } +#define ACCELERATOR_INSTALL_COLL_API(__comm, __module, __api) \ + do \ + { \ + if ((__comm)->c_coll->coll_##__api) \ + { \ + MCA_COLL_SAVE_API(__comm, __api, (__module)->c_coll.coll_##__api, (__module)->c_coll.coll_##__api##_module, "accelerator"); \ + MCA_COLL_INSTALL_API(__comm, __api, mca_coll_accelerator_##__api, &__module->super, "accelerator"); \ + } \ + else \ + { \ + opal_show_help("help-mca-coll-base.txt", "comm-select:missing collective", true, \ + "cuda", #__api, ompi_process_info.nodename, \ + mca_coll_accelerator_component.priority); \ + } \ + } while (0) + +#define ACCELERATOR_UNINSTALL_COLL_API(__comm, __module, __api) \ + do \ + { \ + if (&(__module)->super == (__comm)->c_coll->coll_##__api##_module) { \ + MCA_COLL_INSTALL_API(__comm, __api, (__module)->c_coll.coll_##__api, (__module)->c_coll.coll_##__api##_module, "accelerator"); \ + (__module)->c_coll.coll_##__api##_module = NULL; \ + (__module)->c_coll.coll_##__api = NULL; \ + } \ + } while (0) + /* - * Init module on the communicator + * Init/Fini module on the communicator */ -int mca_coll_accelerator_module_enable(mca_coll_base_module_t *module, - struct ompi_communicator_t *comm) +static int +mca_coll_accelerator_module_enable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm) { - bool good = true; - char *msg = NULL; mca_coll_accelerator_module_t *s = (mca_coll_accelerator_module_t*) module; -#define CHECK_AND_RETAIN(src, dst, name) \ - if (NULL == (src)->c_coll->coll_ ## name ## _module) { \ - good = false; \ - msg = #name; \ - } else if (good) { \ - (dst)->c_coll.coll_ ## name ## _module = (src)->c_coll->coll_ ## name ## _module; \ - (dst)->c_coll.coll_ ## name = (src)->c_coll->coll_ ## name; \ - 
OBJ_RETAIN((src)->c_coll->coll_ ## name ## _module); \ - } - - CHECK_AND_RETAIN(comm, s, allreduce); - CHECK_AND_RETAIN(comm, s, reduce); - CHECK_AND_RETAIN(comm, s, reduce_scatter_block); - CHECK_AND_RETAIN(comm, s, scatter); + ACCELERATOR_INSTALL_COLL_API(comm, s, allreduce); + ACCELERATOR_INSTALL_COLL_API(comm, s, reduce); + ACCELERATOR_INSTALL_COLL_API(comm, s, reduce_scatter_block); if (!OMPI_COMM_IS_INTER(comm)) { /* MPI does not define scan/exscan on intercommunicators */ - CHECK_AND_RETAIN(comm, s, exscan); - CHECK_AND_RETAIN(comm, s, scan); + ACCELERATOR_INSTALL_COLL_API(comm, s, exscan); + ACCELERATOR_INSTALL_COLL_API(comm, s, scan); } - /* All done */ - if (good) { - return OMPI_SUCCESS; - } - opal_show_help("help-mpi-coll-accelerator.txt", "missing collective", true, - ompi_process_info.nodename, - mca_coll_accelerator_component.priority, msg); - return OMPI_ERR_NOT_FOUND; + return OMPI_SUCCESS; } +static int +mca_coll_accelerator_module_disable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm) +{ + mca_coll_accelerator_module_t *s = (mca_coll_accelerator_module_t*) module; + + ACCELERATOR_UNINSTALL_COLL_API(comm, s, allreduce); + ACCELERATOR_UNINSTALL_COLL_API(comm, s, reduce); + ACCELERATOR_UNINSTALL_COLL_API(comm, s, reduce_scatter_block); + if (!OMPI_COMM_IS_INTER(comm)) + { + /* MPI does not define scan/exscan on intercommunicators */ + ACCELERATOR_UNINSTALL_COLL_API(comm, s, exscan); + ACCELERATOR_UNINSTALL_COLL_API(comm, s, scan); + } + + return OMPI_SUCCESS; +} diff --git a/ompi/mca/coll/accelerator/help-mpi-coll-accelerator.txt b/ompi/mca/coll/accelerator/help-mpi-coll-accelerator.txt deleted file mode 100644 index abc39d08e12..00000000000 --- a/ompi/mca/coll/accelerator/help-mpi-coll-accelerator.txt +++ /dev/null @@ -1,29 +0,0 @@ -# -*- text -*- -# -# Copyright (c) 2004-2005 The Trustees of Indiana University and Indiana -# University Research and Technology -# Corporation. All rights reserved. -# Copyright (c) 2004-2005 The University of Tennessee and The University -# of Tennessee Research Foundation. All rights -# reserved. -# Copyright (c) 2004-2005 High Performance Computing Center Stuttgart, -# University of Stuttgart. All rights reserved. -# Copyright (c) 2004-2005 The Regents of the University of California. -# All rights reserved. -# Copyright (c) 2014 NVIDIA Corporation. All rights reserved. -# Copyright (c) 2024 Triad National Security, LLC. All rights reserved. -# $COPYRIGHT$ -# -# Additional copyrights may follow -# -# $HEADER$ -# -# This is the US/English general help file for Open MPI's accelerator -# collective reduction component. -# -[missing collective] -There was a problem while initializing support for the accelerator reduction operations. -hostname: %s -priority: %d -collective: %s -# diff --git a/ompi/mca/coll/adapt/coll_adapt_module.c b/ompi/mca/coll/adapt/coll_adapt_module.c index 12e79d72d4a..abd27d09870 100644 --- a/ompi/mca/coll/adapt/coll_adapt_module.c +++ b/ompi/mca/coll/adapt/coll_adapt_module.c @@ -5,6 +5,7 @@ * Copyright (c) 2021 Triad National Security, LLC. All rights * reserved. * Copyright (c) 2022 IBM Corporation. All rights reserved + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * * $COPYRIGHT$ * @@ -83,25 +84,41 @@ OBJ_CLASS_INSTANCE(mca_coll_adapt_module_t, adapt_module_construct, adapt_module_destruct); -/* - * In this macro, the following variables are supposed to have been declared - * in the caller: - * . ompi_communicator_t *comm - * . 
mca_coll_adapt_module_t *adapt_module - */ -#define ADAPT_SAVE_PREV_COLL_API(__api) \ - do { \ - adapt_module->previous_ ## __api = comm->c_coll->coll_ ## __api; \ - adapt_module->previous_ ## __api ## _module = comm->c_coll->coll_ ## __api ## _module; \ - if (!comm->c_coll->coll_ ## __api || !comm->c_coll->coll_ ## __api ## _module) { \ - opal_output_verbose(1, ompi_coll_base_framework.framework_output, \ - "(%s/%s): no underlying " # __api"; disqualifying myself", \ - ompi_comm_print_cid(comm), comm->c_name); \ - return OMPI_ERROR; \ - } \ - OBJ_RETAIN(adapt_module->previous_ ## __api ## _module); \ - } while(0) - +#define ADAPT_INSTALL_COLL_API(__comm, __module, __api) \ + do \ + { \ + if (__module->super.coll_##__api) \ + { \ + MCA_COLL_INSTALL_API(__comm, __api, __module->super.coll_##__api, &__module->super, "adapt"); \ + } \ + } while (0) +#define ADAPT_UNINSTALL_COLL_API(__comm, __module, __api) \ + do \ + { \ + if (__comm->c_coll->coll_##__api##_module == &__module->super) \ + { \ + MCA_COLL_INSTALL_API(__comm, __api, NULL, NULL, "adapt"); \ + } \ + } while (0) +#define ADAPT_INSTALL_AND_SAVE_COLL_API(__comm, __module, __api) \ + do \ + { \ + if (__comm->c_coll->coll_##__api && __comm->c_coll->coll_##__api##_module) \ + { \ + MCA_COLL_SAVE_API(__comm, __api, __module->previous_##__api, __module->previous_##__api##_module, "adapt"); \ + MCA_COLL_INSTALL_API(__comm, __api, __module->super.coll_##__api, &__module->super, "adapt"); \ + } \ + } while (0) +#define ADAPT_UNINSTALL_AND_RESTORE_COLL_API(__comm, __module, __api) \ + do \ + { \ + if (__comm->c_coll->coll_##__api##_module == &__module->super) \ + { \ + MCA_COLL_INSTALL_API(__comm, __api, __module->previous_##__api, __module->previous_##__api##_module, "adapt"); \ + __module->previous_##__api = NULL; \ + __module->previous_##__api##_module = NULL; \ + } \ + } while (0) /* * Init module on the communicator @@ -111,12 +128,25 @@ static int adapt_module_enable(mca_coll_base_module_t * module, { mca_coll_adapt_module_t * adapt_module = (mca_coll_adapt_module_t*) module; - ADAPT_SAVE_PREV_COLL_API(reduce); - ADAPT_SAVE_PREV_COLL_API(ireduce); + ADAPT_INSTALL_AND_SAVE_COLL_API(comm, adapt_module, reduce); + ADAPT_INSTALL_COLL_API(comm, adapt_module, bcast); + ADAPT_INSTALL_AND_SAVE_COLL_API(comm, adapt_module, ireduce); + ADAPT_INSTALL_COLL_API(comm, adapt_module, ibcast); return OMPI_SUCCESS; } +static int adapt_module_disable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm) +{ + mca_coll_adapt_module_t *adapt_module = (mca_coll_adapt_module_t *)module; + ADAPT_UNINSTALL_AND_RESTORE_COLL_API(comm, adapt_module, reduce); + ADAPT_UNINSTALL_COLL_API(comm, adapt_module, bcast); + ADAPT_UNINSTALL_AND_RESTORE_COLL_API(comm, adapt_module, ireduce); + ADAPT_UNINSTALL_COLL_API(comm, adapt_module, ibcast); + + return OMPI_SUCCESS; +} /* * Initial query function that is invoked during MPI_INIT, allowing * this component to disqualify itself if it doesn't support the @@ -165,24 +195,11 @@ mca_coll_base_module_t *ompi_coll_adapt_comm_query(struct ompi_communicator_t * /* All is good -- return a module */ adapt_module->super.coll_module_enable = adapt_module_enable; - adapt_module->super.coll_allgather = NULL; - adapt_module->super.coll_allgatherv = NULL; - adapt_module->super.coll_allreduce = NULL; - adapt_module->super.coll_alltoall = NULL; - adapt_module->super.coll_alltoallw = NULL; - adapt_module->super.coll_barrier = NULL; + adapt_module->super.coll_module_disable = adapt_module_disable; adapt_module->super.coll_bcast = 
ompi_coll_adapt_bcast; - adapt_module->super.coll_exscan = NULL; - adapt_module->super.coll_gather = NULL; - adapt_module->super.coll_gatherv = NULL; adapt_module->super.coll_reduce = ompi_coll_adapt_reduce; - adapt_module->super.coll_reduce_scatter = NULL; - adapt_module->super.coll_scan = NULL; - adapt_module->super.coll_scatter = NULL; - adapt_module->super.coll_scatterv = NULL; adapt_module->super.coll_ibcast = ompi_coll_adapt_ibcast; adapt_module->super.coll_ireduce = ompi_coll_adapt_ireduce; - adapt_module->super.coll_iallreduce = NULL; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "coll:adapt:comm_query (%s/%s): pick me! pick me!", diff --git a/ompi/mca/coll/base/coll_base_comm_select.c b/ompi/mca/coll/base/coll_base_comm_select.c index 83175aedc7d..24177501cec 100644 --- a/ompi/mca/coll/base/coll_base_comm_select.c +++ b/ompi/mca/coll/base/coll_base_comm_select.c @@ -22,6 +22,7 @@ * Copyright (c) 2016-2017 IBM Corporation. All rights reserved. * Copyright (c) 2017 FUJITSU LIMITED. All rights reserved. * Copyright (c) 2020 BULL S.A.S. All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -41,6 +42,7 @@ #include "opal/util/argv.h" #include "opal/util/show_help.h" #include "opal/class/opal_list.h" +#include "opal/class/opal_hash_table.h" #include "opal/class/opal_object.h" #include "ompi/mca/mca.h" #include "opal/mca/base/base.h" @@ -71,21 +73,138 @@ static int query_2_4_0(const mca_coll_base_component_2_4_0_t * int *priority, mca_coll_base_module_t ** module); -#define COPY(module, comm, func) \ - do { \ - if (NULL != module->coll_ ## func) { \ - if (NULL != comm->c_coll->coll_ ## func ## _module) { \ - OBJ_RELEASE(comm->c_coll->coll_ ## func ## _module); \ - } \ - comm->c_coll->coll_ ## func = module->coll_ ## func; \ - comm->c_coll->coll_ ## func ## _module = module; \ - OBJ_RETAIN(module); \ - } \ - } while (0) - #define CHECK_NULL(what, comm, func) \ ( (what) = # func , NULL == (comm)->c_coll->coll_ ## func) +static void mca_coll_base_get_component_name(ompi_communicator_t *comm, void* module, char** name) +{ + mca_coll_base_avail_coll_t *avail; + + *name = NULL; + OPAL_LIST_FOREACH(avail, comm->c_coll->module_list, mca_coll_base_avail_coll_t) { + if (avail->ac_module == module) { + *name = (char*) avail->ac_component_name; + break; + } + } +} + +#define PRINT_NAME(comm, func, func_name) \ + do { \ + char *name; \ + mca_coll_base_get_component_name(comm, (void*)comm->c_coll->coll_ ## func ## _module, &name); \ + opal_output_verbose(10, ompi_coll_base_framework.framework_output, \ + "coll:base:comm_select: communicator %s rank %d %s -> %s", comm->c_name, comm->c_my_rank, func_name, name); \ + } while (0); + +#define PRINT_ALL_BLOCKING(comm) \ + do { \ + PRINT_NAME(comm, allgather, "allgather"); \ + PRINT_NAME(comm, allgatherv, "allgatherv"); \ + PRINT_NAME(comm, allreduce, "allreduce"); \ + PRINT_NAME(comm, alltoall, "alltoall"); \ + PRINT_NAME(comm, alltoallv, "alltoallv"); \ + PRINT_NAME(comm, alltoallw, "alltoallw"); \ + PRINT_NAME(comm, barrier, "barrier"); \ + PRINT_NAME(comm, bcast, "bcast"); \ + PRINT_NAME(comm, exscan, "exscan"); \ + PRINT_NAME(comm, gather, "gather"); \ + PRINT_NAME(comm, gatherv, "gatherv"); \ + PRINT_NAME(comm, reduce, "reduce"); \ + PRINT_NAME(comm, reduce_scatter_block, "reduce_scatter_block"); \ + PRINT_NAME(comm, reduce_scatter, "reduce_scatter"); \ + PRINT_NAME(comm, scan, "scan"); \ + PRINT_NAME(comm, scatter, "scatter"); \ + PRINT_NAME(comm, 
scatterv, "scatterv"); \ + PRINT_NAME(comm, neighbor_allgather, "neighbor_allgather"); \ + PRINT_NAME(comm, neighbor_allgatherv, "neighbor_allgatherv"); \ + PRINT_NAME(comm, neighbor_alltoall, "neighbor_alltoall"); \ + PRINT_NAME(comm, neighbor_alltoallv, "neighbor_alltoallv"); \ + PRINT_NAME(comm, neighbor_alltoallw, "neighbor_alltoallw"); \ + PRINT_NAME(comm, reduce_local, "reduce_local"); \ + } while (0); + +#define PRINT_ALL_NB(comm) \ + do { \ + PRINT_NAME(comm, iallgather, "iallgather"); \ + PRINT_NAME(comm, iallgatherv, "iallgatherv");\ + PRINT_NAME(comm, iallreduce, "iallreduce"); \ + PRINT_NAME(comm, ialltoall, "ialltoall"); \ + PRINT_NAME(comm, ialltoallv, "ialltoallv"); \ + PRINT_NAME(comm, ialltoallw, "ialltoallw"); \ + PRINT_NAME(comm, ibarrier, "ibarrier"); \ + PRINT_NAME(comm, ibcast, "ibcast"); \ + PRINT_NAME(comm, iexscan, "iexscan"); \ + PRINT_NAME(comm, igather, "igather"); \ + PRINT_NAME(comm, igatherv, "igatherv"); \ + PRINT_NAME(comm, ireduce, "ireduce"); \ + PRINT_NAME(comm, ireduce_scatter_block, "ireduce_scatter_block"); \ + PRINT_NAME(comm, ireduce_scatter, "ireduce_scatter"); \ + PRINT_NAME(comm, iscan, "iscan"); \ + PRINT_NAME(comm, iscatter, "iscatter"); \ + PRINT_NAME(comm, iscatterv, "iscatterv"); \ + PRINT_NAME(comm, ineighbor_allgather, "ineighbor_allgather"); \ + PRINT_NAME(comm, ineighbor_allgatherv, "ineighbor_allgatherv"); \ + PRINT_NAME(comm, ineighbor_alltoall, "ineighbor_alltoall"); \ + PRINT_NAME(comm, ineighbor_alltoallv, "ineighbor_alltoallv"); \ + PRINT_NAME(comm, ineighbor_alltoallw, "ineighbor_alltoallw"); \ + } while (0); + +#define PRINT_ALL_PERSISTENT(comm) \ + do { \ + PRINT_NAME(comm, allgather_init, "allgather_init"); \ + PRINT_NAME(comm, allgatherv_init, "allgatherv_init"); \ + PRINT_NAME(comm, allreduce_init, "allreduce_init"); \ + PRINT_NAME(comm, alltoall_init, "alltoall_init"); \ + PRINT_NAME(comm, alltoallv_init, "alltoallv_init"); \ + PRINT_NAME(comm, alltoallw_init, "alltoallw_init"); \ + PRINT_NAME(comm, barrier_init, "barrier_init"); \ + PRINT_NAME(comm, bcast_init, "bcast_init"); \ + PRINT_NAME(comm, exscan_init, "exscan_init"); \ + PRINT_NAME(comm, gather_init, "gather_init"); \ + PRINT_NAME(comm, gatherv_init, "gatherv_init"); \ + PRINT_NAME(comm, reduce_init, "reduce_init"); \ + PRINT_NAME(comm, reduce_scatter_block_init, "reduce_scatter_block_init"); \ + PRINT_NAME(comm, reduce_scatter_init, "reduce_scatter_init"); \ + PRINT_NAME(comm, scan_init, "scan_init"); \ + PRINT_NAME(comm, scatter_init, "scatter_init"); \ + PRINT_NAME(comm, scatterv_init, "scatterv_init"); \ + PRINT_NAME(comm, neighbor_allgather_init, "neighbor_allgather_init"); \ + PRINT_NAME(comm, neighbor_allgatherv_init, "neighbor_allgatherv_init"); \ + PRINT_NAME(comm, neighbor_alltoall_init, "neighbor_alltoall_init"); \ + PRINT_NAME(comm, neighbor_alltoallv_init, "neighbor_alltoallv_init"); \ + PRINT_NAME(comm, neighbor_alltoallw_init, "neighbor_alltoallw_init"); \ + } while (0); + +#define PRINT_ALL_FT(comm) \ + do { \ + PRINT_NAME(comm, agree, "agree"); \ + PRINT_NAME(comm, iagree, "iagree"); \ + } while (0); + +static void mca_coll_base_print_component_names(ompi_communicator_t *comm) +{ + /* + ** Verbosity level 1 - 19 will only print the blocking and non-blocking collectives + ** assigned to MPI_COMM_WORLD, but not the persistent and ft ones. + ** + ** Verbosity level 20 will print all blocking and non-blocking collectives for all communicators, + ** but not the persistent and ft ones. 
+ ** + ** Verbosity level > 20 will print all collectives for all communicators. + */ + if ( (MPI_COMM_WORLD == comm) || (ompi_coll_base_framework.framework_verbose >= 20)) { + PRINT_ALL_BLOCKING (comm); + PRINT_ALL_NB (comm); + if (ompi_coll_base_framework.framework_verbose > 20) { + PRINT_ALL_PERSISTENT (comm); +#if OPAL_ENABLE_FT_MPI + PRINT_ALL_FT (comm); +#endif + } + } +} + /* * This function is called at the initialization time of every * communicator. It is used to select which coll component will be @@ -134,7 +253,6 @@ int mca_coll_base_comm_select(ompi_communicator_t * comm) NULL != item; item = opal_list_remove_first(selectable)) { mca_coll_base_avail_coll_t *avail = (mca_coll_base_avail_coll_t *) item; - /* initialize the module */ ret = avail->ac_module->coll_module_enable(avail->ac_module, comm); @@ -147,95 +265,12 @@ int mca_coll_base_comm_select(ompi_communicator_t * comm) /* Save every component that is initialized, * queried and enabled successfully */ opal_list_append(comm->c_coll->module_list, &avail->super); - - /* copy over any of the pointers */ - COPY(avail->ac_module, comm, allgather); - COPY(avail->ac_module, comm, allgatherv); - COPY(avail->ac_module, comm, allreduce); - COPY(avail->ac_module, comm, alltoall); - COPY(avail->ac_module, comm, alltoallv); - COPY(avail->ac_module, comm, alltoallw); - COPY(avail->ac_module, comm, barrier); - COPY(avail->ac_module, comm, bcast); - COPY(avail->ac_module, comm, exscan); - COPY(avail->ac_module, comm, gather); - COPY(avail->ac_module, comm, gatherv); - COPY(avail->ac_module, comm, reduce); - COPY(avail->ac_module, comm, reduce_scatter_block); - COPY(avail->ac_module, comm, reduce_scatter); - COPY(avail->ac_module, comm, scan); - COPY(avail->ac_module, comm, scatter); - COPY(avail->ac_module, comm, scatterv); - - COPY(avail->ac_module, comm, iallgather); - COPY(avail->ac_module, comm, iallgatherv); - COPY(avail->ac_module, comm, iallreduce); - COPY(avail->ac_module, comm, ialltoall); - COPY(avail->ac_module, comm, ialltoallv); - COPY(avail->ac_module, comm, ialltoallw); - COPY(avail->ac_module, comm, ibarrier); - COPY(avail->ac_module, comm, ibcast); - COPY(avail->ac_module, comm, iexscan); - COPY(avail->ac_module, comm, igather); - COPY(avail->ac_module, comm, igatherv); - COPY(avail->ac_module, comm, ireduce); - COPY(avail->ac_module, comm, ireduce_scatter_block); - COPY(avail->ac_module, comm, ireduce_scatter); - COPY(avail->ac_module, comm, iscan); - COPY(avail->ac_module, comm, iscatter); - COPY(avail->ac_module, comm, iscatterv); - - COPY(avail->ac_module, comm, allgather_init); - COPY(avail->ac_module, comm, allgatherv_init); - COPY(avail->ac_module, comm, allreduce_init); - COPY(avail->ac_module, comm, alltoall_init); - COPY(avail->ac_module, comm, alltoallv_init); - COPY(avail->ac_module, comm, alltoallw_init); - COPY(avail->ac_module, comm, barrier_init); - COPY(avail->ac_module, comm, bcast_init); - COPY(avail->ac_module, comm, exscan_init); - COPY(avail->ac_module, comm, gather_init); - COPY(avail->ac_module, comm, gatherv_init); - COPY(avail->ac_module, comm, reduce_init); - COPY(avail->ac_module, comm, reduce_scatter_block_init); - COPY(avail->ac_module, comm, reduce_scatter_init); - COPY(avail->ac_module, comm, scan_init); - COPY(avail->ac_module, comm, scatter_init); - COPY(avail->ac_module, comm, scatterv_init); - - /* We can not reliably check if this comm has a topology - * at this time. 
The flags are set *after* coll_select */ - COPY(avail->ac_module, comm, neighbor_allgather); - COPY(avail->ac_module, comm, neighbor_allgatherv); - COPY(avail->ac_module, comm, neighbor_alltoall); - COPY(avail->ac_module, comm, neighbor_alltoallv); - COPY(avail->ac_module, comm, neighbor_alltoallw); - - COPY(avail->ac_module, comm, ineighbor_allgather); - COPY(avail->ac_module, comm, ineighbor_allgatherv); - COPY(avail->ac_module, comm, ineighbor_alltoall); - COPY(avail->ac_module, comm, ineighbor_alltoallv); - COPY(avail->ac_module, comm, ineighbor_alltoallw); - - COPY(avail->ac_module, comm, neighbor_allgather_init); - COPY(avail->ac_module, comm, neighbor_allgatherv_init); - COPY(avail->ac_module, comm, neighbor_alltoall_init); - COPY(avail->ac_module, comm, neighbor_alltoallv_init); - COPY(avail->ac_module, comm, neighbor_alltoallw_init); - - COPY(avail->ac_module, comm, reduce_local); - -#if OPAL_ENABLE_FT_MPI - COPY(avail->ac_module, comm, agree); - COPY(avail->ac_module, comm, iagree); -#endif } else { /* release the original module reference and the list item */ OBJ_RELEASE(avail->ac_module); OBJ_RELEASE(avail); } } - /* Done with the list from the check_components() call so release it. */ OBJ_RELEASE(selectable); @@ -291,6 +326,10 @@ int mca_coll_base_comm_select(ompi_communicator_t * comm) ((OMPI_COMM_IS_INTRA(comm)) && CHECK_NULL(which_func, comm, scan_init)) || CHECK_NULL(which_func, comm, scatter_init) || CHECK_NULL(which_func, comm, scatterv_init) || +#if OPAL_ENABLE_FT_MPI + CHECK_NULL(which_func, comm, agree) || + CHECK_NULL(which_func, comm, iagree) || +#endif /* OPAL_ENABLE_FT_MPI */ CHECK_NULL(which_func, comm, reduce_local) ) { /* TODO -- Once the topology flags are set before coll_select then * check if neighborhood collectives have been set. */ @@ -298,9 +337,14 @@ int mca_coll_base_comm_select(ompi_communicator_t * comm) opal_show_help("help-mca-coll-base.txt", "comm-select:no-function-available", true, which_func); - mca_coll_base_comm_unselect(comm); + mca_coll_base_comm_unselect(comm); return OMPI_ERR_NOT_FOUND; } + + if (ompi_coll_base_framework.framework_verbose > 0) { + mca_coll_base_print_component_names(comm); + } + return OMPI_SUCCESS; } diff --git a/ompi/mca/coll/base/coll_base_comm_unselect.c b/ompi/mca/coll/base/coll_base_comm_unselect.c index 9608b498cba..aca7bb23095 100644 --- a/ompi/mca/coll/base/coll_base_comm_unselect.c +++ b/ompi/mca/coll/base/coll_base_comm_unselect.c @@ -17,6 +17,7 @@ * Copyright (c) 2017 IBM Corporation. All rights reserved. * Copyright (c) 2017 FUJITSU LIMITED. All rights reserved. * Copyright (c) 2020 BULL S.A.S. All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. 
* $COPYRIGHT$ * * Additional copyrights may follow @@ -37,113 +38,120 @@ #include "ompi/mca/coll/base/base.h" #include "ompi/mca/coll/base/coll_base_util.h" -#define CLOSE(comm, func) \ +#if OPAL_ENABLE_DEBUG +#define CHECK_CLEAN_COLL(comm, func) \ do { \ - if (NULL != comm->c_coll->coll_ ## func ## _module) { \ - if (NULL != comm->c_coll->coll_ ## func ## _module->coll_module_disable) { \ - comm->c_coll->coll_ ## func ## _module->coll_module_disable( \ - comm->c_coll->coll_ ## func ## _module, comm); \ - } \ - OBJ_RELEASE(comm->c_coll->coll_ ## func ## _module); \ - comm->c_coll->coll_## func = NULL; \ - comm->c_coll->coll_## func ## _module = NULL; \ + if (NULL != comm->c_coll->coll_ ## func ## _module || \ + NULL != comm->c_coll->coll_ ## func ) { \ + opal_output_verbose(10, ompi_coll_base_framework.framework_output, \ + "coll:base:comm_unselect: Comm %p (%s) has a left over %s collective during cleanup", \ + (void*)comm, comm->c_name, #func); \ } \ } while (0) +#else +#define CHECK_CLEAN_COLL(comm, func) +#endif /* OPAL_ENABLE_DEBUG */ int mca_coll_base_comm_unselect(ompi_communicator_t * comm) { opal_list_item_t *item; - CLOSE(comm, allgather); - CLOSE(comm, allgatherv); - CLOSE(comm, allreduce); - CLOSE(comm, alltoall); - CLOSE(comm, alltoallv); - CLOSE(comm, alltoallw); - CLOSE(comm, barrier); - CLOSE(comm, bcast); - CLOSE(comm, exscan); - CLOSE(comm, gather); - CLOSE(comm, gatherv); - CLOSE(comm, reduce); - CLOSE(comm, reduce_scatter_block); - CLOSE(comm, reduce_scatter); - CLOSE(comm, scan); - CLOSE(comm, scatter); - CLOSE(comm, scatterv); - - CLOSE(comm, iallgather); - CLOSE(comm, iallgatherv); - CLOSE(comm, iallreduce); - CLOSE(comm, ialltoall); - CLOSE(comm, ialltoallv); - CLOSE(comm, ialltoallw); - CLOSE(comm, ibarrier); - CLOSE(comm, ibcast); - CLOSE(comm, iexscan); - CLOSE(comm, igather); - CLOSE(comm, igatherv); - CLOSE(comm, ireduce); - CLOSE(comm, ireduce_scatter_block); - CLOSE(comm, ireduce_scatter); - CLOSE(comm, iscan); - CLOSE(comm, iscatter); - CLOSE(comm, iscatterv); - - CLOSE(comm, allgather_init); - CLOSE(comm, allgatherv_init); - CLOSE(comm, allreduce_init); - CLOSE(comm, alltoall_init); - CLOSE(comm, alltoallv_init); - CLOSE(comm, alltoallw_init); - CLOSE(comm, barrier_init); - CLOSE(comm, bcast_init); - CLOSE(comm, exscan_init); - CLOSE(comm, gather_init); - CLOSE(comm, gatherv_init); - CLOSE(comm, reduce_init); - CLOSE(comm, reduce_scatter_block_init); - CLOSE(comm, reduce_scatter_init); - CLOSE(comm, scan_init); - CLOSE(comm, scatter_init); - CLOSE(comm, scatterv_init); - - CLOSE(comm, neighbor_allgather); - CLOSE(comm, neighbor_allgatherv); - CLOSE(comm, neighbor_alltoall); - CLOSE(comm, neighbor_alltoallv); - CLOSE(comm, neighbor_alltoallw); - - CLOSE(comm, ineighbor_allgather); - CLOSE(comm, ineighbor_allgatherv); - CLOSE(comm, ineighbor_alltoall); - CLOSE(comm, ineighbor_alltoallv); - CLOSE(comm, ineighbor_alltoallw); - - CLOSE(comm, neighbor_allgather_init); - CLOSE(comm, neighbor_allgatherv_init); - CLOSE(comm, neighbor_alltoall_init); - CLOSE(comm, neighbor_alltoallv_init); - CLOSE(comm, neighbor_alltoallw_init); - - CLOSE(comm, reduce_local); - -#if OPAL_ENABLE_FT_MPI - CLOSE(comm, agree); - CLOSE(comm, iagree); -#endif - - for (item = opal_list_remove_first(comm->c_coll->module_list); - NULL != item; item = opal_list_remove_first(comm->c_coll->module_list)) { + /* Call module disable in the reverse order in which enable has been called + * in order to allow the modules to properly chain themselves. 
+ */ + for (item = opal_list_remove_last(comm->c_coll->module_list); + NULL != item; item = opal_list_remove_last(comm->c_coll->module_list)) { mca_coll_base_avail_coll_t *avail = (mca_coll_base_avail_coll_t *) item; if(avail->ac_module) { + if (NULL != avail->ac_module->coll_module_disable ) { + avail->ac_module->coll_module_disable(avail->ac_module, comm); + } OBJ_RELEASE(avail->ac_module); } OBJ_RELEASE(avail); } OBJ_RELEASE(comm->c_coll->module_list); + CHECK_CLEAN_COLL(comm, allgather); + CHECK_CLEAN_COLL(comm, allgatherv); + CHECK_CLEAN_COLL(comm, allreduce); + CHECK_CLEAN_COLL(comm, alltoall); + CHECK_CLEAN_COLL(comm, alltoallv); + CHECK_CLEAN_COLL(comm, alltoallw); + CHECK_CLEAN_COLL(comm, barrier); + CHECK_CLEAN_COLL(comm, bcast); + CHECK_CLEAN_COLL(comm, exscan); + CHECK_CLEAN_COLL(comm, gather); + CHECK_CLEAN_COLL(comm, gatherv); + CHECK_CLEAN_COLL(comm, reduce); + CHECK_CLEAN_COLL(comm, reduce_scatter_block); + CHECK_CLEAN_COLL(comm, reduce_scatter); + CHECK_CLEAN_COLL(comm, scan); + CHECK_CLEAN_COLL(comm, scatter); + CHECK_CLEAN_COLL(comm, scatterv); + + CHECK_CLEAN_COLL(comm, iallgather); + CHECK_CLEAN_COLL(comm, iallgatherv); + CHECK_CLEAN_COLL(comm, iallreduce); + CHECK_CLEAN_COLL(comm, ialltoall); + CHECK_CLEAN_COLL(comm, ialltoallv); + CHECK_CLEAN_COLL(comm, ialltoallw); + CHECK_CLEAN_COLL(comm, ibarrier); + CHECK_CLEAN_COLL(comm, ibcast); + CHECK_CLEAN_COLL(comm, iexscan); + CHECK_CLEAN_COLL(comm, igather); + CHECK_CLEAN_COLL(comm, igatherv); + CHECK_CLEAN_COLL(comm, ireduce); + CHECK_CLEAN_COLL(comm, ireduce_scatter_block); + CHECK_CLEAN_COLL(comm, ireduce_scatter); + CHECK_CLEAN_COLL(comm, iscan); + CHECK_CLEAN_COLL(comm, iscatter); + CHECK_CLEAN_COLL(comm, iscatterv); + + CHECK_CLEAN_COLL(comm, allgather_init); + CHECK_CLEAN_COLL(comm, allgatherv_init); + CHECK_CLEAN_COLL(comm, allreduce_init); + CHECK_CLEAN_COLL(comm, alltoall_init); + CHECK_CLEAN_COLL(comm, alltoallv_init); + CHECK_CLEAN_COLL(comm, alltoallw_init); + CHECK_CLEAN_COLL(comm, barrier_init); + CHECK_CLEAN_COLL(comm, bcast_init); + CHECK_CLEAN_COLL(comm, exscan_init); + CHECK_CLEAN_COLL(comm, gather_init); + CHECK_CLEAN_COLL(comm, gatherv_init); + CHECK_CLEAN_COLL(comm, reduce_init); + CHECK_CLEAN_COLL(comm, reduce_scatter_block_init); + CHECK_CLEAN_COLL(comm, reduce_scatter_init); + CHECK_CLEAN_COLL(comm, scan_init); + CHECK_CLEAN_COLL(comm, scatter_init); + CHECK_CLEAN_COLL(comm, scatterv_init); + + CHECK_CLEAN_COLL(comm, neighbor_allgather); + CHECK_CLEAN_COLL(comm, neighbor_allgatherv); + CHECK_CLEAN_COLL(comm, neighbor_alltoall); + CHECK_CLEAN_COLL(comm, neighbor_alltoallv); + CHECK_CLEAN_COLL(comm, neighbor_alltoallw); + + CHECK_CLEAN_COLL(comm, ineighbor_allgather); + CHECK_CLEAN_COLL(comm, ineighbor_allgatherv); + CHECK_CLEAN_COLL(comm, ineighbor_alltoall); + CHECK_CLEAN_COLL(comm, ineighbor_alltoallv); + CHECK_CLEAN_COLL(comm, ineighbor_alltoallw); + + CHECK_CLEAN_COLL(comm, neighbor_allgather_init); + CHECK_CLEAN_COLL(comm, neighbor_allgatherv_init); + CHECK_CLEAN_COLL(comm, neighbor_alltoall_init); + CHECK_CLEAN_COLL(comm, neighbor_alltoallv_init); + CHECK_CLEAN_COLL(comm, neighbor_alltoallw_init); + + CHECK_CLEAN_COLL(comm, reduce_local); + +#if OPAL_ENABLE_FT_MPI + CHECK_CLEAN_COLL(comm, agree); + CHECK_CLEAN_COLL(comm, iagree); +#endif + free(comm->c_coll); comm->c_coll = NULL; diff --git a/ompi/mca/coll/base/help-mca-coll-base.txt b/ompi/mca/coll/base/help-mca-coll-base.txt index d6e0071fa7a..f5afb0ce19b 100644 --- a/ompi/mca/coll/base/help-mca-coll-base.txt +++ 
b/ompi/mca/coll/base/help-mca-coll-base.txt @@ -10,7 +10,8 @@ # University of Stuttgart. All rights reserved. # Copyright (c) 2004-2005 The Regents of the University of California. # All rights reserved. -# Copyright (c) 2015 Cisco Systems, Inc. All rights reserved. +# Copyright (c) 2015 Cisco Systems, Inc. All rights reserved. +# Copyright (c) 2024 NVIDIA Corporation. All rights reserved. # $COPYRIGHT$ # # Additional copyrights may follow @@ -45,3 +46,10 @@ using it was destroyed. This is somewhat unusual: the module itself may be at fault, or this may be a symptom of another issue (e.g., a memory problem). +# +[comm-select:missing collective] +The %s %s collective requires support from a prior collective module for the fallback +case. Such support is missing in this run. +hostname: %s +priority: %d +# diff --git a/ompi/mca/coll/basic/coll_basic.h b/ompi/mca/coll/basic/coll_basic.h index 3cb36692e47..8e38655be00 100644 --- a/ompi/mca/coll/basic/coll_basic.h +++ b/ompi/mca/coll/basic/coll_basic.h @@ -53,9 +53,6 @@ BEGIN_C_DECLS *mca_coll_basic_comm_query(struct ompi_communicator_t *comm, int *priority); - int mca_coll_basic_module_enable(mca_coll_base_module_t *module, - struct ompi_communicator_t *comm); - int mca_coll_basic_allgather_inter(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void *rbuf, int rcount, diff --git a/ompi/mca/coll/basic/coll_basic_module.c b/ompi/mca/coll/basic/coll_basic_module.c index 6391b735b5d..9913a3cda26 100644 --- a/ompi/mca/coll/basic/coll_basic_module.c +++ b/ompi/mca/coll/basic/coll_basic_module.c @@ -36,6 +36,13 @@ #include "coll_basic.h" +static int +mca_coll_basic_module_enable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm); +static int +mca_coll_basic_module_disable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm); + /* * Initial query function that is invoked during MPI_INIT, allowing * this component to disqualify itself if it doesn't support the @@ -50,7 +57,6 @@ mca_coll_basic_init_query(bool enable_progress_threads, return OMPI_SUCCESS; } - /* * Invoked when there's a new communicator that has been created. * Look at the communicator and decide which set of functions and @@ -70,82 +76,19 @@ mca_coll_basic_comm_query(struct ompi_communicator_t *comm, /* Choose whether to use [intra|inter], and [linear|log]-based * algorithms. 
*/ basic_module->super.coll_module_enable = mca_coll_basic_module_enable; - - if (OMPI_COMM_IS_INTER(comm)) { - basic_module->super.coll_allgather = mca_coll_basic_allgather_inter; - basic_module->super.coll_allgatherv = mca_coll_basic_allgatherv_inter; - basic_module->super.coll_allreduce = mca_coll_basic_allreduce_inter; - basic_module->super.coll_alltoall = mca_coll_basic_alltoall_inter; - basic_module->super.coll_alltoallv = mca_coll_basic_alltoallv_inter; - basic_module->super.coll_alltoallw = mca_coll_basic_alltoallw_inter; - basic_module->super.coll_barrier = mca_coll_basic_barrier_inter_lin; - basic_module->super.coll_bcast = mca_coll_basic_bcast_lin_inter; - basic_module->super.coll_exscan = NULL; - basic_module->super.coll_gather = mca_coll_basic_gather_inter; - basic_module->super.coll_gatherv = mca_coll_basic_gatherv_inter; - basic_module->super.coll_reduce = mca_coll_basic_reduce_lin_inter; - basic_module->super.coll_reduce_scatter_block = mca_coll_basic_reduce_scatter_block_inter; - basic_module->super.coll_reduce_scatter = mca_coll_basic_reduce_scatter_inter; - basic_module->super.coll_scan = NULL; - basic_module->super.coll_scatter = mca_coll_basic_scatter_inter; - basic_module->super.coll_scatterv = mca_coll_basic_scatterv_inter; - } else if (ompi_comm_size(comm) <= mca_coll_basic_crossover) { - basic_module->super.coll_allgather = ompi_coll_base_allgather_intra_basic_linear; - basic_module->super.coll_allgatherv = ompi_coll_base_allgatherv_intra_basic_default; - basic_module->super.coll_allreduce = mca_coll_basic_allreduce_intra; - basic_module->super.coll_alltoall = ompi_coll_base_alltoall_intra_basic_linear; - basic_module->super.coll_alltoallv = ompi_coll_base_alltoallv_intra_basic_linear; - basic_module->super.coll_alltoallw = mca_coll_basic_alltoallw_intra; - basic_module->super.coll_barrier = ompi_coll_base_barrier_intra_basic_linear; - basic_module->super.coll_bcast = ompi_coll_base_bcast_intra_basic_linear; - basic_module->super.coll_exscan = ompi_coll_base_exscan_intra_linear; - basic_module->super.coll_gather = ompi_coll_base_gather_intra_basic_linear; - basic_module->super.coll_gatherv = mca_coll_basic_gatherv_intra; - basic_module->super.coll_reduce = ompi_coll_base_reduce_intra_basic_linear; - basic_module->super.coll_reduce_scatter_block = mca_coll_basic_reduce_scatter_block_intra; - basic_module->super.coll_reduce_scatter = mca_coll_basic_reduce_scatter_intra; - basic_module->super.coll_scan = ompi_coll_base_scan_intra_linear; - basic_module->super.coll_scatter = ompi_coll_base_scatter_intra_basic_linear; - basic_module->super.coll_scatterv = mca_coll_basic_scatterv_intra; - } else { - basic_module->super.coll_allgather = ompi_coll_base_allgather_intra_basic_linear; - basic_module->super.coll_allgatherv = ompi_coll_base_allgatherv_intra_basic_default; - basic_module->super.coll_allreduce = mca_coll_basic_allreduce_intra; - basic_module->super.coll_alltoall = ompi_coll_base_alltoall_intra_basic_linear; - basic_module->super.coll_alltoallv = ompi_coll_base_alltoallv_intra_basic_linear; - basic_module->super.coll_alltoallw = mca_coll_basic_alltoallw_intra; - basic_module->super.coll_barrier = mca_coll_basic_barrier_intra_log; - basic_module->super.coll_bcast = mca_coll_basic_bcast_log_intra; - basic_module->super.coll_exscan = ompi_coll_base_exscan_intra_linear; - basic_module->super.coll_gather = ompi_coll_base_gather_intra_basic_linear; - basic_module->super.coll_gatherv = mca_coll_basic_gatherv_intra; - basic_module->super.coll_reduce = 
mca_coll_basic_reduce_log_intra; - basic_module->super.coll_reduce_scatter_block = mca_coll_basic_reduce_scatter_block_intra; - basic_module->super.coll_reduce_scatter = mca_coll_basic_reduce_scatter_intra; - basic_module->super.coll_scan = ompi_coll_base_scan_intra_linear; - basic_module->super.coll_scatter = ompi_coll_base_scatter_intra_basic_linear; - basic_module->super.coll_scatterv = mca_coll_basic_scatterv_intra; - } - - /* These functions will return an error code if comm does not have a virtual topology */ - basic_module->super.coll_neighbor_allgather = mca_coll_basic_neighbor_allgather; - basic_module->super.coll_neighbor_allgatherv = mca_coll_basic_neighbor_allgatherv; - basic_module->super.coll_neighbor_alltoall = mca_coll_basic_neighbor_alltoall; - basic_module->super.coll_neighbor_alltoallv = mca_coll_basic_neighbor_alltoallv; - basic_module->super.coll_neighbor_alltoallw = mca_coll_basic_neighbor_alltoallw; - - basic_module->super.coll_reduce_local = mca_coll_base_reduce_local; - -#if OPAL_ENABLE_FT_MPI - /* Default to some shim mappings over allreduce */ - basic_module->super.coll_agree = ompi_coll_base_agree_noft; - basic_module->super.coll_iagree = ompi_coll_base_iagree_noft; -#endif + basic_module->super.coll_module_disable = mca_coll_basic_module_disable; return &(basic_module->super); } +#define BASIC_INSTALL_COLL_API(__comm, __module, __api, __coll) \ + do \ + { \ + (__module)->super.coll_##__api = __coll; \ + MCA_COLL_INSTALL_API(__comm, __api, __coll, &__module->super, "basic"); \ + } while (0) + /* * Init module on the communicator */ @@ -153,12 +96,158 @@ int mca_coll_basic_module_enable(mca_coll_base_module_t *module, struct ompi_communicator_t *comm) { + mca_coll_basic_module_t *basic_module = (mca_coll_basic_module_t*)module; + /* prepare the placeholder for the array of request* */ module->base_data = OBJ_NEW(mca_coll_base_comm_t); if (NULL == module->base_data) { return OMPI_ERROR; } + if (OMPI_COMM_IS_INTER(comm)) + { + BASIC_INSTALL_COLL_API(comm, basic_module, allgather, mca_coll_basic_allgather_inter); + BASIC_INSTALL_COLL_API(comm, basic_module, allgatherv, mca_coll_basic_allgatherv_inter); + BASIC_INSTALL_COLL_API(comm, basic_module, allreduce, mca_coll_basic_allreduce_inter); + BASIC_INSTALL_COLL_API(comm, basic_module, alltoall, mca_coll_basic_alltoall_inter); + BASIC_INSTALL_COLL_API(comm, basic_module, alltoallv, mca_coll_basic_alltoallv_inter); + BASIC_INSTALL_COLL_API(comm, basic_module, alltoallw, mca_coll_basic_alltoallw_inter); + BASIC_INSTALL_COLL_API(comm, basic_module, barrier, mca_coll_basic_barrier_inter_lin); + BASIC_INSTALL_COLL_API(comm, basic_module, bcast, mca_coll_basic_bcast_lin_inter); + BASIC_INSTALL_COLL_API(comm, basic_module, gather, mca_coll_basic_gather_inter); + BASIC_INSTALL_COLL_API(comm, basic_module, gatherv, mca_coll_basic_gatherv_inter); + BASIC_INSTALL_COLL_API(comm, basic_module, reduce, mca_coll_basic_reduce_lin_inter); + BASIC_INSTALL_COLL_API(comm, basic_module, reduce_scatter_block, mca_coll_basic_reduce_scatter_block_inter); + BASIC_INSTALL_COLL_API(comm, basic_module, reduce_scatter, mca_coll_basic_reduce_scatter_inter); + BASIC_INSTALL_COLL_API(comm, basic_module, scatter, mca_coll_basic_scatter_inter); + BASIC_INSTALL_COLL_API(comm, basic_module, scatterv, mca_coll_basic_scatterv_inter); + } + else { + if (ompi_comm_size(comm) <= mca_coll_basic_crossover) { + BASIC_INSTALL_COLL_API(comm, basic_module, barrier, ompi_coll_base_barrier_intra_basic_linear); + BASIC_INSTALL_COLL_API(comm, basic_module, bcast, 
ompi_coll_base_bcast_intra_basic_linear); + BASIC_INSTALL_COLL_API(comm, basic_module, reduce, ompi_coll_base_reduce_intra_basic_linear); + } else { + BASIC_INSTALL_COLL_API(comm, basic_module, barrier, mca_coll_basic_barrier_intra_log); + BASIC_INSTALL_COLL_API(comm, basic_module, bcast, mca_coll_basic_bcast_log_intra); + BASIC_INSTALL_COLL_API(comm, basic_module, reduce, mca_coll_basic_reduce_log_intra); + } + BASIC_INSTALL_COLL_API(comm, basic_module, allgather, ompi_coll_base_allgather_intra_basic_linear); + BASIC_INSTALL_COLL_API(comm, basic_module, allgatherv, ompi_coll_base_allgatherv_intra_basic_default); + BASIC_INSTALL_COLL_API(comm, basic_module, allreduce, mca_coll_basic_allreduce_intra); + BASIC_INSTALL_COLL_API(comm, basic_module, alltoall, ompi_coll_base_alltoall_intra_basic_linear); + BASIC_INSTALL_COLL_API(comm, basic_module, alltoallv, ompi_coll_base_alltoallv_intra_basic_linear); + BASIC_INSTALL_COLL_API(comm, basic_module, alltoallw, mca_coll_basic_alltoallw_intra); + BASIC_INSTALL_COLL_API(comm, basic_module, exscan, ompi_coll_base_exscan_intra_linear); + BASIC_INSTALL_COLL_API(comm, basic_module, gather, ompi_coll_base_gather_intra_basic_linear); + BASIC_INSTALL_COLL_API(comm, basic_module, gatherv, mca_coll_basic_gatherv_intra); + BASIC_INSTALL_COLL_API(comm, basic_module, reduce_scatter_block, mca_coll_basic_reduce_scatter_block_intra); + BASIC_INSTALL_COLL_API(comm, basic_module, reduce_scatter, mca_coll_basic_reduce_scatter_intra); + BASIC_INSTALL_COLL_API(comm, basic_module, scan, ompi_coll_base_scan_intra_linear); + BASIC_INSTALL_COLL_API(comm, basic_module, scatter, ompi_coll_base_scatter_intra_basic_linear); + BASIC_INSTALL_COLL_API(comm, basic_module, scatterv, mca_coll_basic_scatterv_intra); + } + /* These functions will return an error code if comm does not have a virtual topology */ + BASIC_INSTALL_COLL_API(comm, basic_module, neighbor_allgather, mca_coll_basic_neighbor_allgather); + BASIC_INSTALL_COLL_API(comm, basic_module, neighbor_allgatherv, mca_coll_basic_neighbor_allgatherv); + BASIC_INSTALL_COLL_API(comm, basic_module, neighbor_alltoall, mca_coll_basic_neighbor_alltoall); + BASIC_INSTALL_COLL_API(comm, basic_module, neighbor_alltoallv, mca_coll_basic_neighbor_alltoallv); + BASIC_INSTALL_COLL_API(comm, basic_module, neighbor_alltoallw, mca_coll_basic_neighbor_alltoallw); + + /* Default to some shim mappings over allreduce */ + BASIC_INSTALL_COLL_API(comm, basic_module, agree, ompi_coll_base_agree_noft); + BASIC_INSTALL_COLL_API(comm, basic_module, iagree, ompi_coll_base_iagree_noft); + + BASIC_INSTALL_COLL_API(comm, basic_module, reduce_local, mca_coll_base_reduce_local); + + /* All done */ + return OMPI_SUCCESS; +} + +#define BASIC_UNINSTALL_COLL_API(__comm, __module, __api) \ + do \ + { \ + if (__comm->c_coll->coll_##__api##_module == &__module->super ) { \ + MCA_COLL_INSTALL_API(__comm, __api, NULL, NULL, "basic"); \ + } \ + } while (0) + +int +mca_coll_basic_module_disable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm) +{ + mca_coll_basic_module_t *basic_module = (mca_coll_basic_module_t*)module; + + if (OMPI_COMM_IS_INTER(comm)) + { + BASIC_UNINSTALL_COLL_API(comm, basic_module, allgather); + BASIC_UNINSTALL_COLL_API(comm, basic_module, allgatherv); + BASIC_UNINSTALL_COLL_API(comm, basic_module, allreduce); + BASIC_UNINSTALL_COLL_API(comm, basic_module, alltoall); + BASIC_UNINSTALL_COLL_API(comm, basic_module, alltoallv); + BASIC_UNINSTALL_COLL_API(comm, basic_module, alltoallw); + BASIC_UNINSTALL_COLL_API(comm, 
basic_module, barrier); + BASIC_UNINSTALL_COLL_API(comm, basic_module, bcast); + BASIC_UNINSTALL_COLL_API(comm, basic_module, gather); + BASIC_UNINSTALL_COLL_API(comm, basic_module, gatherv); + BASIC_UNINSTALL_COLL_API(comm, basic_module, reduce); + BASIC_UNINSTALL_COLL_API(comm, basic_module, reduce_scatter_block); + BASIC_UNINSTALL_COLL_API(comm, basic_module, reduce_scatter); + BASIC_UNINSTALL_COLL_API(comm, basic_module, scatter); + BASIC_UNINSTALL_COLL_API(comm, basic_module, scatterv); + } + else if (ompi_comm_size(comm) <= mca_coll_basic_crossover) + { + BASIC_UNINSTALL_COLL_API(comm, basic_module, allgather); + BASIC_UNINSTALL_COLL_API(comm, basic_module, allgatherv); + BASIC_UNINSTALL_COLL_API(comm, basic_module, allreduce); + BASIC_UNINSTALL_COLL_API(comm, basic_module, alltoall); + BASIC_UNINSTALL_COLL_API(comm, basic_module, alltoallv); + BASIC_UNINSTALL_COLL_API(comm, basic_module, alltoallw); + BASIC_UNINSTALL_COLL_API(comm, basic_module, barrier); + BASIC_UNINSTALL_COLL_API(comm, basic_module, bcast); + BASIC_UNINSTALL_COLL_API(comm, basic_module, exscan); + BASIC_UNINSTALL_COLL_API(comm, basic_module, gather); + BASIC_UNINSTALL_COLL_API(comm, basic_module, gatherv); + BASIC_UNINSTALL_COLL_API(comm, basic_module, reduce); + BASIC_UNINSTALL_COLL_API(comm, basic_module, reduce_scatter_block); + BASIC_UNINSTALL_COLL_API(comm, basic_module, reduce_scatter); + BASIC_UNINSTALL_COLL_API(comm, basic_module, scan); + BASIC_UNINSTALL_COLL_API(comm, basic_module, scatter); + BASIC_UNINSTALL_COLL_API(comm, basic_module, scatterv); + } + else + { + BASIC_UNINSTALL_COLL_API(comm, basic_module, allgather); + BASIC_UNINSTALL_COLL_API(comm, basic_module, allgatherv); + BASIC_UNINSTALL_COLL_API(comm, basic_module, allreduce); + BASIC_UNINSTALL_COLL_API(comm, basic_module, alltoall); + BASIC_UNINSTALL_COLL_API(comm, basic_module, alltoallv); + BASIC_UNINSTALL_COLL_API(comm, basic_module, alltoallw); + BASIC_UNINSTALL_COLL_API(comm, basic_module, barrier); + BASIC_UNINSTALL_COLL_API(comm, basic_module, bcast); + BASIC_UNINSTALL_COLL_API(comm, basic_module, exscan); + BASIC_UNINSTALL_COLL_API(comm, basic_module, gather); + BASIC_UNINSTALL_COLL_API(comm, basic_module, gatherv); + BASIC_UNINSTALL_COLL_API(comm, basic_module, reduce); + BASIC_UNINSTALL_COLL_API(comm, basic_module, reduce_scatter_block); + BASIC_UNINSTALL_COLL_API(comm, basic_module, reduce_scatter); + BASIC_UNINSTALL_COLL_API(comm, basic_module, scan); + BASIC_UNINSTALL_COLL_API(comm, basic_module, scatter); + BASIC_UNINSTALL_COLL_API(comm, basic_module, scatterv); + } + + /* These functions will return an error code if comm does not have a virtual topology */ + BASIC_UNINSTALL_COLL_API(comm, basic_module, neighbor_allgather); + BASIC_UNINSTALL_COLL_API(comm, basic_module, neighbor_allgatherv); + BASIC_UNINSTALL_COLL_API(comm, basic_module, neighbor_alltoall); + BASIC_UNINSTALL_COLL_API(comm, basic_module, neighbor_alltoallv); + BASIC_UNINSTALL_COLL_API(comm, basic_module, neighbor_alltoallw); + + BASIC_UNINSTALL_COLL_API(comm, basic_module, agree); + BASIC_UNINSTALL_COLL_API(comm, basic_module, iagree); + + BASIC_UNINSTALL_COLL_API(comm, basic_module, reduce_local); /* All done */ return OMPI_SUCCESS; } diff --git a/ompi/mca/coll/coll.h b/ompi/mca/coll/coll.h index 4908ece2d60..dd9d8a42a38 100644 --- a/ompi/mca/coll/coll.h +++ b/ompi/mca/coll/coll.h @@ -20,6 +20,7 @@ * Copyright (c) 2016-2017 IBM Corporation. All rights reserved. * Copyright (c) 2017 FUJITSU LIMITED. All rights reserved. * Copyright (c) 2020 BULL S.A.S. 
All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -80,6 +81,7 @@ #include "opal/mca/base/base.h" #include "ompi/request/request.h" +#include "ompi/mca/coll/base/base.h" BEGIN_C_DECLS @@ -817,6 +819,29 @@ typedef struct mca_coll_base_comm_coll_t mca_coll_base_comm_coll_t; /* ******************************************************************** */ +#define MCA_COLL_SAVE_API(__comm, __api, __fct, __bmodule, __c_name) \ + do \ + { \ + OPAL_OUTPUT_VERBOSE((50, ompi_coll_base_framework.framework_output, \ + "Save %s collective in comm %p (%s) %p into %s:%s", \ + #__api, (void *)__comm, __comm->c_name, \ + (void*)(uintptr_t)__comm->c_coll->coll_##__api, \ + __c_name, #__fct)); \ + __fct = __comm->c_coll->coll_##__api; \ + __bmodule = __comm->c_coll->coll_##__api##_module; \ + } while (0) + +#define MCA_COLL_INSTALL_API(__comm, __api, __fct, __bmodule, __c_name) \ + do \ + { \ + OPAL_OUTPUT_VERBOSE((50, ompi_coll_base_framework.framework_output, \ + "Replace %s collective in comm %p (%s) from %p to %s:%s(%p)", \ + #__api, (void *)__comm, __comm->c_name, \ + (void*)(uintptr_t)__comm->c_coll->coll_##__api, \ + __c_name, #__fct, (void*)(uintptr_t)__fct)); \ + __comm->c_coll->coll_##__api = __fct; \ + __comm->c_coll->coll_##__api##_module = __bmodule; \ + } while (0) END_C_DECLS diff --git a/ompi/mca/coll/demo/coll_demo.h b/ompi/mca/coll/demo/coll_demo.h index 20ebe728b7a..d9ee2b633e2 100644 --- a/ompi/mca/coll/demo/coll_demo.h +++ b/ompi/mca/coll/demo/coll_demo.h @@ -10,6 +10,7 @@ * Copyright (c) 2004-2005 The Regents of the University of California. * All rights reserved. * Copyright (c) 2008 Cisco Systems, Inc. All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. 
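For readers following the new coll.h helpers above: MCA_COLL_SAVE_API records the collective function pointer and owning module currently selected on a communicator, and MCA_COLL_INSTALL_API overwrites both (emitting a verbosity-50 trace of the swap). The sketch below is not part of the patch; it only illustrates, for a hypothetical component "mycomp" and a hypothetical function my_bcast, what installing bcast amounts to once the trace output is stripped away.

    /* Roughly what MCA_COLL_INSTALL_API(comm, bcast, my_bcast, &my_module->super, "mycomp")
     * performs, minus the OPAL_OUTPUT_VERBOSE trace (illustrative only, not patch content). */
    static void install_bcast_sketch(struct ompi_communicator_t *comm,
                                     mca_coll_base_module_t *owner,
                                     mca_coll_base_module_bcast_fn_t my_bcast)
    {
        comm->c_coll->coll_bcast        = my_bcast;  /* function the communicator will now call */
        comm->c_coll->coll_bcast_module = owner;     /* module that provides (and owns) it */
    }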
* $COPYRIGHT$ * * Additional copyrights may follow @@ -30,7 +31,7 @@ BEGIN_C_DECLS /* Globally exported variables */ -OMPI_DECLSPEC extern const mca_coll_base_component_2_4_0_t mca_coll_demo_component; + OMPI_DECLSPEC extern const mca_coll_base_component_2_4_0_t mca_coll_demo_component; extern int mca_coll_demo_priority; extern int mca_coll_demo_verbose; @@ -39,87 +40,84 @@ OMPI_DECLSPEC extern const mca_coll_base_component_2_4_0_t mca_coll_demo_compone int mca_coll_demo_init_query(bool enable_progress_threads, bool enable_mpi_threads); -mca_coll_base_module_t * -mca_coll_demo_comm_query(struct ompi_communicator_t *comm, int *priority); + mca_coll_base_module_t * + mca_coll_demo_comm_query(struct ompi_communicator_t *comm, int *priority); /* Module functions */ -int mca_coll_demo_module_enable(mca_coll_base_module_t *module, - struct ompi_communicator_t *comm); - - int mca_coll_demo_allgather_intra(void *sbuf, int scount, + int mca_coll_demo_allgather_intra(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void *rbuf, int rcount, struct ompi_datatype_t *rdtype, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_allgather_inter(void *sbuf, int scount, + int mca_coll_demo_allgather_inter(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void *rbuf, int rcount, struct ompi_datatype_t *rdtype, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_allgatherv_intra(void *sbuf, int scount, + int mca_coll_demo_allgatherv_intra(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, - void * rbuf, int *rcounts, int *disps, + void * rbuf, const int *rcounts, const int *disps, struct ompi_datatype_t *rdtype, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_allgatherv_inter(void *sbuf, int scount, + int mca_coll_demo_allgatherv_inter(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, - void * rbuf, int *rcounts, int *disps, + void * rbuf, const int *rcounts, const int *disps, struct ompi_datatype_t *rdtype, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_allreduce_intra(void *sbuf, void *rbuf, int count, + int mca_coll_demo_allreduce_intra(const void *sbuf, void *rbuf, int count, struct ompi_datatype_t *dtype, struct ompi_op_t *op, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_allreduce_inter(void *sbuf, void *rbuf, int count, + int mca_coll_demo_allreduce_inter(const void *sbuf, void *rbuf, int count, struct ompi_datatype_t *dtype, struct ompi_op_t *op, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_alltoall_intra(void *sbuf, int scount, + int mca_coll_demo_alltoall_intra(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void* rbuf, int rcount, struct ompi_datatype_t *rdtype, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_alltoall_inter(void *sbuf, int scount, + int mca_coll_demo_alltoall_inter(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void* rbuf, int rcount, struct ompi_datatype_t *rdtype, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_alltoallv_intra(void *sbuf, int *scounts, int *sdisps, + int mca_coll_demo_alltoallv_intra(const void *sbuf, const int *scounts, const int *sdisps, struct ompi_datatype_t *sdtype, - void *rbuf, int *rcounts, int *rdisps, + void *rbuf, const int *rcounts, const int *rdisps, struct ompi_datatype_t 
*rdtype, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_alltoallv_inter(void *sbuf, int *scounts, int *sdisps, + int mca_coll_demo_alltoallv_inter(const void *sbuf, const int *scounts, const int *sdisps, struct ompi_datatype_t *sdtype, - void *rbuf, int *rcounts, int *rdisps, + void *rbuf, const int *rcounts, const int *rdisps, struct ompi_datatype_t *rdtype, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_alltoallw_intra(void *sbuf, int *scounts, int *sdisps, - struct ompi_datatype_t **sdtypes, - void *rbuf, int *rcounts, int *rdisps, - struct ompi_datatype_t **rdtypes, + int mca_coll_demo_alltoallw_intra(const void *sbuf, const int *scounts, const int *sdisps, + struct ompi_datatype_t *const *sdtypes, + void *rbuf, const int *rcounts, const int *rdisps, + struct ompi_datatype_t *const *rdtypes, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_alltoallw_inter(void *sbuf, int *scounts, int *sdisps, - struct ompi_datatype_t **sdtypes, - void *rbuf, int *rcounts, int *rdisps, - struct ompi_datatype_t **rdtypes, + int mca_coll_demo_alltoallw_inter(const void *sbuf, const int *scounts, const int *sdisps, + struct ompi_datatype_t *const *sdtypes, + void *rbuf, const int *rcounts, const int *rdisps, + struct ompi_datatype_t *const *rdtypes, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); @@ -139,109 +137,109 @@ int mca_coll_demo_module_enable(mca_coll_base_module_t *module, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_exscan_intra(void *sbuf, void *rbuf, int count, + int mca_coll_demo_exscan_intra(const void *sbuf, void *rbuf, int count, struct ompi_datatype_t *dtype, struct ompi_op_t *op, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_exscan_inter(void *sbuf, void *rbuf, int count, + int mca_coll_demo_exscan_inter(const void *sbuf, void *rbuf, int count, struct ompi_datatype_t *dtype, struct ompi_op_t *op, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_gather_intra(void *sbuf, int scount, + int mca_coll_demo_gather_intra(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void *rbuf, int rcount, struct ompi_datatype_t *rdtype, int root, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_gather_inter(void *sbuf, int scount, + int mca_coll_demo_gather_inter(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void *rbuf, int rcount, struct ompi_datatype_t *rdtype, int root, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_gatherv_intra(void *sbuf, int scount, + int mca_coll_demo_gatherv_intra(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void *rbuf, - int *rcounts, int *disps, + const int *rcounts, const int *disps, struct ompi_datatype_t *rdtype, int root, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_gatherv_inter(void *sbuf, int scount, + int mca_coll_demo_gatherv_inter(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void *rbuf, - int *rcounts, int *disps, + const int *rcounts, const int *disps, struct ompi_datatype_t *rdtype, int root, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_reduce_intra(void *sbuf, void* rbuf, int count, + int mca_coll_demo_reduce_intra(const void *sbuf, void* rbuf, int count, struct ompi_datatype_t *dtype, struct 
ompi_op_t *op, int root, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_reduce_inter(void *sbuf, void* rbuf, int count, + int mca_coll_demo_reduce_inter(const void *sbuf, void* rbuf, int count, struct ompi_datatype_t *dtype, struct ompi_op_t *op, int root, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_reduce_scatter_intra(void *sbuf, void *rbuf, - int *rcounts, + int mca_coll_demo_reduce_scatter_intra(const void *sbuf, void *rbuf, + const int *rcounts, struct ompi_datatype_t *dtype, struct ompi_op_t *op, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_reduce_scatter_inter(void *sbuf, void *rbuf, - int *rcounts, + int mca_coll_demo_reduce_scatter_inter(const void *sbuf, void *rbuf, + const int *rcounts, struct ompi_datatype_t *dtype, struct ompi_op_t *op, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_scan_intra(void *sbuf, void *rbuf, int count, + int mca_coll_demo_scan_intra(const void *sbuf, void *rbuf, int count, struct ompi_datatype_t *dtype, struct ompi_op_t *op, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_scan_inter(void *sbuf, void *rbuf, int count, + int mca_coll_demo_scan_inter(const void *sbuf, void *rbuf, int count, struct ompi_datatype_t *dtype, struct ompi_op_t *op, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_scatter_intra(void *sbuf, int scount, + int mca_coll_demo_scatter_intra(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void *rbuf, int rcount, struct ompi_datatype_t *rdtype, int root, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_scatter_inter(void *sbuf, int scount, + int mca_coll_demo_scatter_inter(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void *rbuf, int rcount, struct ompi_datatype_t *rdtype, int root, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_scatterv_intra(void *sbuf, int *scounts, int *disps, + int mca_coll_demo_scatterv_intra(const void *sbuf, const int *scounts, const int *disps, struct ompi_datatype_t *sdtype, void* rbuf, int rcount, struct ompi_datatype_t *rdtype, int root, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); - int mca_coll_demo_scatterv_inter(void *sbuf, int *scounts, int *disps, + int mca_coll_demo_scatterv_inter(const void *sbuf, const int *scounts, const int *disps, struct ompi_datatype_t *sdtype, void* rbuf, int rcount, struct ompi_datatype_t *rdtype, int root, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); -struct mca_coll_demo_module_t { - mca_coll_base_module_t super; + struct mca_coll_demo_module_t { + mca_coll_base_module_t super; - mca_coll_base_comm_coll_t underlying; -}; -typedef struct mca_coll_demo_module_t mca_coll_demo_module_t; -OBJ_CLASS_DECLARATION(mca_coll_demo_module_t); + mca_coll_base_comm_coll_t c_coll; + }; + typedef struct mca_coll_demo_module_t mca_coll_demo_module_t; + OBJ_CLASS_DECLARATION(mca_coll_demo_module_t); END_C_DECLS diff --git a/ompi/mca/coll/demo/coll_demo_allgather.c b/ompi/mca/coll/demo/coll_demo_allgather.c index f33042a71d5..c2f2a820d45 100644 --- a/ompi/mca/coll/demo/coll_demo_allgather.c +++ b/ompi/mca/coll/demo/coll_demo_allgather.c @@ -9,6 +9,7 @@ * University of Stuttgart. All rights reserved. * Copyright (c) 2004-2005 The Regents of the University of California. * All rights reserved. 
+ * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -33,7 +34,7 @@ * Accepts: - same as MPI_Allgather() * Returns: - MPI_SUCCESS or error code */ -int mca_coll_demo_allgather_intra(void *sbuf, int scount, +int mca_coll_demo_allgather_intra(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void *rbuf, int rcount, struct ompi_datatype_t *rdtype, struct ompi_communicator_t *comm, @@ -41,9 +42,9 @@ int mca_coll_demo_allgather_intra(void *sbuf, int scount, { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo allgather_intra"); - return demo_module->underlying.coll_allgather(sbuf, scount, sdtype, rbuf, - rcount, rdtype, comm, - demo_module->underlying.coll_allgather_module); + return demo_module->c_coll.coll_allgather(sbuf, scount, sdtype, rbuf, + rcount, rdtype, comm, + demo_module->c_coll.coll_allgather_module); } @@ -54,7 +55,7 @@ int mca_coll_demo_allgather_intra(void *sbuf, int scount, * Accepts: - same as MPI_Allgather() * Returns: - MPI_SUCCESS or error code */ -int mca_coll_demo_allgather_inter(void *sbuf, int scount, +int mca_coll_demo_allgather_inter(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void *rbuf, int rcount, struct ompi_datatype_t *rdtype, @@ -63,7 +64,7 @@ int mca_coll_demo_allgather_inter(void *sbuf, int scount, { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo allgather_inter"); - return demo_module->underlying.coll_allgather(sbuf, scount, sdtype, rbuf, - rcount, rdtype, comm, - demo_module->underlying.coll_allgather_module); + return demo_module->c_coll.coll_allgather(sbuf, scount, sdtype, rbuf, + rcount, rdtype, comm, + demo_module->c_coll.coll_allgather_module); } diff --git a/ompi/mca/coll/demo/coll_demo_allgatherv.c b/ompi/mca/coll/demo/coll_demo_allgatherv.c index b6503ec6865..1da300c3879 100644 --- a/ompi/mca/coll/demo/coll_demo_allgatherv.c +++ b/ompi/mca/coll/demo/coll_demo_allgatherv.c @@ -9,6 +9,7 @@ * University of Stuttgart. All rights reserved. * Copyright (c) 2004-2005 The Regents of the University of California. * All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. 
* $COPYRIGHT$ * * Additional copyrights may follow @@ -33,19 +34,19 @@ * Accepts: - same as MPI_Allgatherv() * Returns: - MPI_SUCCESS or error code */ -int mca_coll_demo_allgatherv_intra(void *sbuf, int scount, +int mca_coll_demo_allgatherv_intra(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, - void * rbuf, int *rcounts, int *disps, + void * rbuf, const int *rcounts, const int *disps, struct ompi_datatype_t *rdtype, struct ompi_communicator_t *comm, mca_coll_base_module_t *module) { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo allgatherv_intra"); - return demo_module->underlying.coll_allgatherv(sbuf, scount, sdtype, - rbuf, rcounts, disps, - rdtype, comm, - demo_module->underlying.coll_allgatherv_module); + return demo_module->c_coll.coll_allgatherv(sbuf, scount, sdtype, + rbuf, rcounts, disps, + rdtype, comm, + demo_module->c_coll.coll_allgatherv_module); } @@ -56,17 +57,17 @@ int mca_coll_demo_allgatherv_intra(void *sbuf, int scount, * Accepts: - same as MPI_Allgatherv() * Returns: - MPI_SUCCESS or error code */ -int mca_coll_demo_allgatherv_inter(void *sbuf, int scount, - struct ompi_datatype_t *sdtype, - void * rbuf, int *rcounts, int *disps, - struct ompi_datatype_t *rdtype, +int mca_coll_demo_allgatherv_inter(const void *sbuf, int scount, + struct ompi_datatype_t *sdtype, + void * rbuf, const int *rcounts, const int *disps, + struct ompi_datatype_t *rdtype, struct ompi_communicator_t *comm, mca_coll_base_module_t *module) { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo allgatherv_inter"); - return demo_module->underlying.coll_allgatherv(sbuf, scount, sdtype, - rbuf, rcounts, disps, - rdtype, comm, - demo_module->underlying.coll_allgatherv_module); + return demo_module->c_coll.coll_allgatherv(sbuf, scount, sdtype, + rbuf, rcounts, disps, + rdtype, comm, + demo_module->c_coll.coll_allgatherv_module); } diff --git a/ompi/mca/coll/demo/coll_demo_allreduce.c b/ompi/mca/coll/demo/coll_demo_allreduce.c index 15975bacb1c..f8f833b1de1 100644 --- a/ompi/mca/coll/demo/coll_demo_allreduce.c +++ b/ompi/mca/coll/demo/coll_demo_allreduce.c @@ -9,6 +9,7 @@ * University of Stuttgart. All rights reserved. * Copyright (c) 2004-2005 The Regents of the University of California. * All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. 
* $COPYRIGHT$ * * Additional copyrights may follow @@ -33,7 +34,7 @@ * Accepts: - same as MPI_Allreduce() * Returns: - MPI_SUCCESS or error code */ -int mca_coll_demo_allreduce_intra(void *sbuf, void *rbuf, int count, +int mca_coll_demo_allreduce_intra(const void *sbuf, void *rbuf, int count, struct ompi_datatype_t *dtype, struct ompi_op_t *op, struct ompi_communicator_t *comm, @@ -41,9 +42,9 @@ int mca_coll_demo_allreduce_intra(void *sbuf, void *rbuf, int count, { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo allreduce_intra"); - return demo_module->underlying.coll_allreduce(sbuf, rbuf, count, dtype, - op, comm, - demo_module->underlying.coll_allreduce_module); + return demo_module->c_coll.coll_allreduce(sbuf, rbuf, count, dtype, + op, comm, + demo_module->c_coll.coll_allreduce_module); } @@ -54,7 +55,7 @@ int mca_coll_demo_allreduce_intra(void *sbuf, void *rbuf, int count, * Accepts: - same as MPI_Allreduce() * Returns: - MPI_SUCCESS or error code */ -int mca_coll_demo_allreduce_inter(void *sbuf, void *rbuf, int count, +int mca_coll_demo_allreduce_inter(const void *sbuf, void *rbuf, int count, struct ompi_datatype_t *dtype, struct ompi_op_t *op, struct ompi_communicator_t *comm, @@ -62,7 +63,7 @@ int mca_coll_demo_allreduce_inter(void *sbuf, void *rbuf, int count, { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo allreduce_inter"); - return demo_module->underlying.coll_allreduce(sbuf, rbuf, count, dtype, - op, comm, - demo_module->underlying.coll_allreduce_module); + return demo_module->c_coll.coll_allreduce(sbuf, rbuf, count, dtype, + op, comm, + demo_module->c_coll.coll_allreduce_module); } diff --git a/ompi/mca/coll/demo/coll_demo_alltoall.c b/ompi/mca/coll/demo/coll_demo_alltoall.c index d3559970121..41b5e7f4c3a 100644 --- a/ompi/mca/coll/demo/coll_demo_alltoall.c +++ b/ompi/mca/coll/demo/coll_demo_alltoall.c @@ -9,6 +9,7 @@ * University of Stuttgart. All rights reserved. * Copyright (c) 2004-2005 The Regents of the University of California. * All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. 
* $COPYRIGHT$ * * Additional copyrights may follow @@ -33,7 +34,7 @@ * Accepts: - same as MPI_Alltoall() * Returns: - MPI_SUCCESS or an MPI error code */ -int mca_coll_demo_alltoall_intra(void *sbuf, int scount, +int mca_coll_demo_alltoall_intra(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void *rbuf, int rcount, struct ompi_datatype_t *rdtype, @@ -42,10 +43,10 @@ int mca_coll_demo_alltoall_intra(void *sbuf, int scount, { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo alltoall_intra\n"); - return demo_module->underlying.coll_alltoall(sbuf, scount, sdtype, - rbuf, rcount, rdtype, - comm, - demo_module->underlying.coll_alltoall_module); + return demo_module->c_coll.coll_alltoall(sbuf, scount, sdtype, + rbuf, rcount, rdtype, + comm, + demo_module->c_coll.coll_alltoall_module); } @@ -56,7 +57,7 @@ int mca_coll_demo_alltoall_intra(void *sbuf, int scount, * Accepts: - same as MPI_Alltoall() * Returns: - MPI_SUCCESS or an MPI error code */ -int mca_coll_demo_alltoall_inter(void *sbuf, int scount, +int mca_coll_demo_alltoall_inter(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void *rbuf, int rcount, struct ompi_datatype_t *rdtype, @@ -65,8 +66,8 @@ int mca_coll_demo_alltoall_inter(void *sbuf, int scount, { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo alltoall_inter\n"); - return demo_module->underlying.coll_alltoall(sbuf, scount, sdtype, - rbuf, rcount, rdtype, - comm, - demo_module->underlying.coll_alltoall_module); + return demo_module->c_coll.coll_alltoall(sbuf, scount, sdtype, + rbuf, rcount, rdtype, + comm, + demo_module->c_coll.coll_alltoall_module); } diff --git a/ompi/mca/coll/demo/coll_demo_alltoallv.c b/ompi/mca/coll/demo/coll_demo_alltoallv.c index 0e8cf13861b..512a3654fd2 100644 --- a/ompi/mca/coll/demo/coll_demo_alltoallv.c +++ b/ompi/mca/coll/demo/coll_demo_alltoallv.c @@ -9,6 +9,7 @@ * University of Stuttgart. All rights reserved. * Copyright (c) 2004-2005 The Regents of the University of California. * All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. 
* $COPYRIGHT$ * * Additional copyrights may follow @@ -34,19 +35,19 @@ * Returns: - MPI_SUCCESS or an MPI error code */ int -mca_coll_demo_alltoallv_intra(void *sbuf, int *scounts, int *sdisps, +mca_coll_demo_alltoallv_intra(const void *sbuf, const int *scounts, const int *sdisps, struct ompi_datatype_t *sdtype, - void *rbuf, int *rcounts, int *rdisps, + void *rbuf, const int *rcounts, const int *rdisps, struct ompi_datatype_t *rdtype, struct ompi_communicator_t *comm, mca_coll_base_module_t *module) { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo alltoallv_intra"); - return demo_module->underlying.coll_alltoallv(sbuf, scounts, sdisps, - sdtype, rbuf, rcounts, - rdisps, rdtype, comm, - demo_module->underlying.coll_alltoallv_module); + return demo_module->c_coll.coll_alltoallv(sbuf, scounts, sdisps, + sdtype, rbuf, rcounts, + rdisps, rdtype, comm, + demo_module->c_coll.coll_alltoallv_module); } @@ -58,17 +59,17 @@ mca_coll_demo_alltoallv_intra(void *sbuf, int *scounts, int *sdisps, * Returns: - MPI_SUCCESS or an MPI error code */ int -mca_coll_demo_alltoallv_inter(void *sbuf, int *scounts, int *sdisps, +mca_coll_demo_alltoallv_inter(const void *sbuf, const int *scounts, const int *sdisps, struct ompi_datatype_t *sdtype, void *rbuf, - int *rcounts, int *rdisps, + const int *rcounts, const int *rdisps, struct ompi_datatype_t *rdtype, struct ompi_communicator_t *comm, mca_coll_base_module_t *module) { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo alltoallv_inter"); - return demo_module->underlying.coll_alltoallv(sbuf, scounts, sdisps, - sdtype, rbuf, rcounts, - rdisps, rdtype, comm, - demo_module->underlying.coll_alltoallv_module); + return demo_module->c_coll.coll_alltoallv(sbuf, scounts, sdisps, + sdtype, rbuf, rcounts, + rdisps, rdtype, comm, + demo_module->c_coll.coll_alltoallv_module); } diff --git a/ompi/mca/coll/demo/coll_demo_alltoallw.c b/ompi/mca/coll/demo/coll_demo_alltoallw.c index b9c29693178..de4bd8433ba 100644 --- a/ompi/mca/coll/demo/coll_demo_alltoallw.c +++ b/ompi/mca/coll/demo/coll_demo_alltoallw.c @@ -9,6 +9,7 @@ * University of Stuttgart. All rights reserved. * Copyright (c) 2004-2005 The Regents of the University of California. * All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. 
* $COPYRIGHT$ * * Additional copyrights may follow @@ -33,19 +34,19 @@ * Accepts: - same as MPI_Alltoallw() * Returns: - MPI_SUCCESS or an MPI error code */ -int mca_coll_demo_alltoallw_intra(void *sbuf, int *scounts, int *sdisps, - struct ompi_datatype_t **sdtypes, - void *rbuf, int *rcounts, int *rdisps, - struct ompi_datatype_t **rdtypes, +int mca_coll_demo_alltoallw_intra(const void *sbuf, const int *scounts, const int *sdisps, + struct ompi_datatype_t *const *sdtypes, + void *rbuf, const int *rcounts, const int *rdisps, + struct ompi_datatype_t *const *rdtypes, struct ompi_communicator_t *comm, mca_coll_base_module_t *module) { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo alltoallw_intra"); - return demo_module->underlying.coll_alltoallw(sbuf, scounts, sdisps, - sdtypes, rbuf, rcounts, - rdisps, rdtypes, comm, - demo_module->underlying.coll_alltoallw_module); + return demo_module->c_coll.coll_alltoallw(sbuf, scounts, sdisps, + sdtypes, rbuf, rcounts, + rdisps, rdtypes, comm, + demo_module->c_coll.coll_alltoallw_module); } @@ -56,17 +57,17 @@ int mca_coll_demo_alltoallw_intra(void *sbuf, int *scounts, int *sdisps, * Accepts: - same as MPI_Alltoallw() * Returns: - MPI_SUCCESS or an MPI error code */ -int mca_coll_demo_alltoallw_inter(void *sbuf, int *scounts, int *sdisps, - struct ompi_datatype_t **sdtypes, - void *rbuf, int *rcounts, int *rdisps, - struct ompi_datatype_t **rdtypes, +int mca_coll_demo_alltoallw_inter(const void *sbuf, const int *scounts, const int *sdisps, + struct ompi_datatype_t *const *sdtypes, + void *rbuf, const int *rcounts, const int *rdisps, + struct ompi_datatype_t *const *rdtypes, struct ompi_communicator_t *comm, mca_coll_base_module_t *module) { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo alltoallw_inter"); - return demo_module->underlying.coll_alltoallw(sbuf, scounts, sdisps, - sdtypes, rbuf, rcounts, - rdisps, rdtypes, comm, - demo_module->underlying.coll_alltoallw_module); + return demo_module->c_coll.coll_alltoallw(sbuf, scounts, sdisps, + sdtypes, rbuf, rcounts, + rdisps, rdtypes, comm, + demo_module->c_coll.coll_alltoallw_module); } diff --git a/ompi/mca/coll/demo/coll_demo_barrier.c b/ompi/mca/coll/demo/coll_demo_barrier.c index bcede2bf5b5..7965b8856ae 100644 --- a/ompi/mca/coll/demo/coll_demo_barrier.c +++ b/ompi/mca/coll/demo/coll_demo_barrier.c @@ -9,6 +9,7 @@ * University of Stuttgart. All rights reserved. * Copyright (c) 2004-2005 The Regents of the University of California. * All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. 
* $COPYRIGHT$ * * Additional copyrights may follow @@ -38,8 +39,8 @@ int mca_coll_demo_barrier_intra(struct ompi_communicator_t *comm, { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo barrier_intra"); - return demo_module->underlying.coll_barrier(comm, - demo_module->underlying.coll_barrier_module); + return demo_module->c_coll.coll_barrier(comm, + demo_module->c_coll.coll_barrier_module); } @@ -55,6 +56,6 @@ int mca_coll_demo_barrier_inter(struct ompi_communicator_t *comm, { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo barrier_inter"); - return demo_module->underlying.coll_barrier(comm, - demo_module->underlying.coll_barrier_module); + return demo_module->c_coll.coll_barrier(comm, + demo_module->c_coll.coll_barrier_module); } diff --git a/ompi/mca/coll/demo/coll_demo_bcast.c b/ompi/mca/coll/demo/coll_demo_bcast.c index 645c9e0dd62..3d42555bfc6 100644 --- a/ompi/mca/coll/demo/coll_demo_bcast.c +++ b/ompi/mca/coll/demo/coll_demo_bcast.c @@ -9,6 +9,7 @@ * University of Stuttgart. All rights reserved. * Copyright (c) 2004-2005 The Regents of the University of California. * All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -40,9 +41,9 @@ int mca_coll_demo_bcast_intra(void *buff, int count, { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo bcast_intra"); - return demo_module->underlying.coll_bcast(buff, count, datatype, - root, comm, - demo_module->underlying.coll_bcast_module); + return demo_module->c_coll.coll_bcast(buff, count, datatype, + root, comm, + demo_module->c_coll.coll_bcast_module); } @@ -60,7 +61,7 @@ int mca_coll_demo_bcast_inter(void *buff, int count, { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo bcast_inter"); - return demo_module->underlying.coll_bcast(buff, count, datatype, - root, comm, - demo_module->underlying.coll_bcast_module); + return demo_module->c_coll.coll_bcast(buff, count, datatype, + root, comm, + demo_module->c_coll.coll_bcast_module); } diff --git a/ompi/mca/coll/demo/coll_demo_component.c b/ompi/mca/coll/demo/coll_demo_component.c index ee23252b2e2..80f91da046a 100644 --- a/ompi/mca/coll/demo/coll_demo_component.c +++ b/ompi/mca/coll/demo/coll_demo_component.c @@ -13,6 +13,7 @@ * Copyright (c) 2008 Cisco Systems, Inc. All rights reserved. * Copyright (c) 2015 Los Alamos National Security, LLC. All rights * reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. 
* $COPYRIGHT$ * * Additional copyrights may follow @@ -105,39 +106,10 @@ static int demo_register(void) static void mca_coll_demo_module_construct(mca_coll_demo_module_t *module) { - memset(&module->underlying, 0, sizeof(mca_coll_base_comm_coll_t)); + memset(&module->c_coll, 0, sizeof(mca_coll_base_comm_coll_t)); } -#define RELEASE(module, func) \ - do { \ - if (NULL != module->underlying.coll_ ## func ## _module) { \ - OBJ_RELEASE(module->underlying.coll_ ## func ## _module); \ - } \ - } while (0) - -static void -mca_coll_demo_module_destruct(mca_coll_demo_module_t *module) -{ - RELEASE(module, allgather); - RELEASE(module, allgatherv); - RELEASE(module, allreduce); - RELEASE(module, alltoall); - RELEASE(module, alltoallv); - RELEASE(module, alltoallw); - RELEASE(module, barrier); - RELEASE(module, bcast); - RELEASE(module, exscan); - RELEASE(module, gather); - RELEASE(module, gatherv); - RELEASE(module, reduce); - RELEASE(module, reduce_scatter); - RELEASE(module, scan); - RELEASE(module, scatter); - RELEASE(module, scatterv); -} - - OBJ_CLASS_INSTANCE(mca_coll_demo_module_t, mca_coll_base_module_t, mca_coll_demo_module_construct, - mca_coll_demo_module_destruct); + NULL); diff --git a/ompi/mca/coll/demo/coll_demo_exscan.c b/ompi/mca/coll/demo/coll_demo_exscan.c index c970369d0dd..de9d469321c 100644 --- a/ompi/mca/coll/demo/coll_demo_exscan.c +++ b/ompi/mca/coll/demo/coll_demo_exscan.c @@ -9,6 +9,7 @@ * University of Stuttgart. All rights reserved. * Copyright (c) 2004-2005 The Regents of the University of California. * All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -33,7 +34,7 @@ * Accepts: - same arguments as MPI_Exscan() * Returns: - MPI_SUCCESS or error code */ -int mca_coll_demo_exscan_intra(void *sbuf, void *rbuf, int count, +int mca_coll_demo_exscan_intra(const void *sbuf, void *rbuf, int count, struct ompi_datatype_t *dtype, struct ompi_op_t *op, struct ompi_communicator_t *comm, @@ -41,9 +42,9 @@ int mca_coll_demo_exscan_intra(void *sbuf, void *rbuf, int count, { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo exscan_intra"); - return demo_module->underlying.coll_exscan(sbuf, rbuf, count, dtype, - op, comm, - demo_module->underlying.coll_exscan_module); + return demo_module->c_coll.coll_exscan(sbuf, rbuf, count, dtype, + op, comm, + demo_module->c_coll.coll_exscan_module); } /* diff --git a/ompi/mca/coll/demo/coll_demo_gather.c b/ompi/mca/coll/demo/coll_demo_gather.c index 9f9840acf8f..d12fb2ad3c3 100644 --- a/ompi/mca/coll/demo/coll_demo_gather.c +++ b/ompi/mca/coll/demo/coll_demo_gather.c @@ -9,6 +9,7 @@ * University of Stuttgart. All rights reserved. * Copyright (c) 2004-2005 The Regents of the University of California. * All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. 
* $COPYRIGHT$ * * Additional copyrights may follow @@ -32,7 +33,7 @@ * Accepts: - same arguments as MPI_Gather() * Returns: - MPI_SUCCESS or error code */ -int mca_coll_demo_gather_intra(void *sbuf, int scount, +int mca_coll_demo_gather_intra(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void *rbuf, int rcount, struct ompi_datatype_t *rdtype, @@ -41,10 +42,10 @@ int mca_coll_demo_gather_intra(void *sbuf, int scount, { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo gather_intra"); - return demo_module->underlying.coll_gather(sbuf, scount, sdtype, - rbuf, rcount, rdtype, - root, comm, - demo_module->underlying.coll_gather_module); + return demo_module->c_coll.coll_gather(sbuf, scount, sdtype, + rbuf, rcount, rdtype, + root, comm, + demo_module->c_coll.coll_gather_module); } @@ -55,7 +56,7 @@ int mca_coll_demo_gather_intra(void *sbuf, int scount, * Accepts: - same arguments as MPI_Gather() * Returns: - MPI_SUCCESS or error code */ -int mca_coll_demo_gather_inter(void *sbuf, int scount, +int mca_coll_demo_gather_inter(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void *rbuf, int rcount, struct ompi_datatype_t *rdtype, @@ -64,8 +65,8 @@ int mca_coll_demo_gather_inter(void *sbuf, int scount, { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo gather_inter"); - return demo_module->underlying.coll_gather(sbuf, scount, sdtype, - rbuf, rcount, rdtype, - root, comm, - demo_module->underlying.coll_gather_module); + return demo_module->c_coll.coll_gather(sbuf, scount, sdtype, + rbuf, rcount, rdtype, + root, comm, + demo_module->c_coll.coll_gather_module); } diff --git a/ompi/mca/coll/demo/coll_demo_gatherv.c b/ompi/mca/coll/demo/coll_demo_gatherv.c index f23b37a0d88..082e10153f2 100644 --- a/ompi/mca/coll/demo/coll_demo_gatherv.c +++ b/ompi/mca/coll/demo/coll_demo_gatherv.c @@ -9,6 +9,7 @@ * University of Stuttgart. All rights reserved. * Copyright (c) 2004-2005 The Regents of the University of California. * All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. 
* $COPYRIGHT$ * * Additional copyrights may follow @@ -33,19 +34,19 @@ * Accepts: - same arguments as MPI_Gatherv() * Returns: - MPI_SUCCESS or error code */ -int mca_coll_demo_gatherv_intra(void *sbuf, int scount, +int mca_coll_demo_gatherv_intra(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, - void *rbuf, int *rcounts, int *disps, + void *rbuf, const int *rcounts, const int *disps, struct ompi_datatype_t *rdtype, int root, struct ompi_communicator_t *comm, mca_coll_base_module_t *module) { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo gatherv_intra"); - return demo_module->underlying.coll_gatherv(sbuf, scount, sdtype, - rbuf, rcounts, disps, - rdtype, root, comm, - demo_module->underlying.coll_gatherv_module); + return demo_module->c_coll.coll_gatherv(sbuf, scount, sdtype, + rbuf, rcounts, disps, + rdtype, root, comm, + demo_module->c_coll.coll_gatherv_module); } @@ -56,17 +57,17 @@ int mca_coll_demo_gatherv_intra(void *sbuf, int scount, * Accepts: - same arguments as MPI_Gatherv() * Returns: - MPI_SUCCESS or error code */ -int mca_coll_demo_gatherv_inter(void *sbuf, int scount, +int mca_coll_demo_gatherv_inter(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, - void *rbuf, int *rcounts, int *disps, + void *rbuf, const int *rcounts, const int *disps, struct ompi_datatype_t *rdtype, int root, struct ompi_communicator_t *comm, mca_coll_base_module_t *module) { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo gatherv_inter"); - return demo_module->underlying.coll_gatherv(sbuf, scount, sdtype, - rbuf, rcounts, disps, - rdtype, root, comm, - demo_module->underlying.coll_gatherv_module); + return demo_module->c_coll.coll_gatherv(sbuf, scount, sdtype, + rbuf, rcounts, disps, + rdtype, root, comm, + demo_module->c_coll.coll_gatherv_module); } diff --git a/ompi/mca/coll/demo/coll_demo_module.c b/ompi/mca/coll/demo/coll_demo_module.c index 5c66c98059f..0201340d9d9 100644 --- a/ompi/mca/coll/demo/coll_demo_module.c +++ b/ompi/mca/coll/demo/coll_demo_module.c @@ -9,6 +9,7 @@ * University of Stuttgart. All rights reserved. * Copyright (c) 2004-2005 The Regents of the University of California. * All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. 
* $COPYRIGHT$ * * Additional copyrights may follow @@ -27,71 +28,6 @@ #include "ompi/mca/coll/base/base.h" #include "coll_demo.h" -#if 0 - -/* - * Linear set of collective algorithms - */ -static const mca_coll_base_module_1_0_0_t intra = { - - /* Initialization / finalization functions */ - - mca_coll_demo_module_init, - mca_coll_demo_module_finalize, - - /* Collective function pointers */ - - mca_coll_demo_allgather_intra, - mca_coll_demo_allgatherv_intra, - mca_coll_demo_allreduce_intra, - mca_coll_demo_alltoall_intra, - mca_coll_demo_alltoallv_intra, - mca_coll_demo_alltoallw_intra, - mca_coll_demo_barrier_intra, - mca_coll_demo_bcast_intra, - NULL, /* Leave exscan blank just to force basic to be used */ - mca_coll_demo_gather_intra, - mca_coll_demo_gatherv_intra, - mca_coll_demo_reduce_intra, - mca_coll_demo_reduce_scatter_intra, - mca_coll_demo_scan_intra, - mca_coll_demo_scatter_intra, - mca_coll_demo_scatterv_intra -}; - - -/* - * Linear set of collective algorithms for intercommunicators - */ -static const mca_coll_base_module_1_0_0_t inter = { - - /* Initialization / finalization functions */ - - mca_coll_demo_module_init, - mca_coll_demo_module_finalize, - - /* Collective function pointers */ - - mca_coll_demo_allgather_inter, - mca_coll_demo_allgatherv_inter, - mca_coll_demo_allreduce_inter, - mca_coll_demo_alltoall_inter, - mca_coll_demo_alltoallv_inter, - mca_coll_demo_alltoallw_inter, - mca_coll_demo_barrier_inter, - mca_coll_demo_bcast_inter, - mca_coll_demo_exscan_inter, - mca_coll_demo_gather_inter, - mca_coll_demo_gatherv_inter, - mca_coll_demo_reduce_inter, - mca_coll_demo_reduce_scatter_inter, - NULL, - mca_coll_demo_scatter_inter, - mca_coll_demo_scatterv_inter -}; - -#endif - /* * Initial query function that is invoked during MPI_INIT, allowing * this component to disqualify itself if it doesn't support the @@ -105,6 +41,13 @@ int mca_coll_demo_init_query(bool enable_progress_threads, return OMPI_SUCCESS; } +static int +mca_coll_demo_module_enable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm); +static int +mca_coll_demo_module_disable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm); + /* * Invoked when there's a new communicator that has been created. 
* Look at the communicator and decide which set of functions and @@ -121,54 +64,38 @@ mca_coll_demo_comm_query(struct ompi_communicator_t *comm, int *priority) *priority = mca_coll_demo_priority; demo_module->super.coll_module_enable = mca_coll_demo_module_enable; - - if (OMPI_COMM_IS_INTRA(comm)) { - demo_module->super.coll_allgather = mca_coll_demo_allgather_intra; - demo_module->super.coll_allgatherv = mca_coll_demo_allgatherv_intra; - demo_module->super.coll_allreduce = mca_coll_demo_allreduce_intra; - demo_module->super.coll_alltoall = mca_coll_demo_alltoall_intra; - demo_module->super.coll_alltoallv = mca_coll_demo_alltoallv_intra; - demo_module->super.coll_alltoallw = mca_coll_demo_alltoallw_intra; - demo_module->super.coll_barrier = mca_coll_demo_barrier_intra; - demo_module->super.coll_bcast = mca_coll_demo_bcast_intra; - demo_module->super.coll_exscan = mca_coll_demo_exscan_intra; - demo_module->super.coll_gather = mca_coll_demo_gather_intra; - demo_module->super.coll_gatherv = mca_coll_demo_gatherv_intra; - demo_module->super.coll_reduce = mca_coll_demo_reduce_intra; - demo_module->super.coll_reduce_scatter = mca_coll_demo_reduce_scatter_intra; - demo_module->super.coll_scan = mca_coll_demo_scan_intra; - demo_module->super.coll_scatter = mca_coll_demo_scatter_intra; - demo_module->super.coll_scatterv = mca_coll_demo_scatterv_intra; - } else { - demo_module->super.coll_allgather = mca_coll_demo_allgather_inter; - demo_module->super.coll_allgatherv = mca_coll_demo_allgatherv_inter; - demo_module->super.coll_allreduce = mca_coll_demo_allreduce_inter; - demo_module->super.coll_alltoall = mca_coll_demo_alltoall_inter; - demo_module->super.coll_alltoallv = mca_coll_demo_alltoallv_inter; - demo_module->super.coll_alltoallw = mca_coll_demo_alltoallw_inter; - demo_module->super.coll_barrier = mca_coll_demo_barrier_inter; - demo_module->super.coll_bcast = mca_coll_demo_bcast_inter; - demo_module->super.coll_exscan = NULL; - demo_module->super.coll_gather = mca_coll_demo_gather_inter; - demo_module->super.coll_gatherv = mca_coll_demo_gatherv_inter; - demo_module->super.coll_reduce = mca_coll_demo_reduce_inter; - demo_module->super.coll_reduce_scatter = mca_coll_demo_reduce_scatter_inter; - demo_module->super.coll_scan = NULL; - demo_module->super.coll_scatter = mca_coll_demo_scatter_inter; - demo_module->super.coll_scatterv = mca_coll_demo_scatterv_inter; - } + demo_module->super.coll_module_disable = mca_coll_demo_module_disable; return &(demo_module->super); } -#define COPY(comm, module, func) \ - do { \ - module->underlying.coll_ ## func = comm->c_coll->coll_ ## func; \ - module->underlying.coll_ ## func = comm->c_coll->coll_ ## func; \ - if (NULL != module->underlying.coll_ ## func ## _module) { \ - OBJ_RETAIN(module->underlying.coll_ ## func ## _module); \ - } \ - } while (0) +#define DEMO_INSTALL_COLL_API(__comm, __module, __api, __func) \ + do \ + { \ + if (__comm->c_coll->coll_##__api) \ + { \ + /* save the current selected collective */ \ + MCA_COLL_SAVE_API(__comm, __api, (__module)->c_coll.coll_##__api, (__module)->c_coll.coll_##__api##_module, "demo"); \ + /* install our own */ \ + MCA_COLL_INSTALL_API(__comm, __api, __func, &__module->super, "demo"); \ + } \ + else \ + { \ + opal_show_help("help-mca-coll-base.txt", "comm-select:missing collective", true, \ + "demo", #__api, ompi_process_info.nodename, \ + mca_coll_demo_priority); \ + } \ + } while (0) + +#define DEMO_UNINSTALL_COLL_API(__comm, __module, __api) \ + do \ + { \ + if (&(__module)->super == 
__comm->c_coll->coll_##__api##_module) \ + { \ + /* put back the original collective */ \ + MCA_COLL_INSTALL_API(__comm, __api, __module->c_coll.coll_##__api, __module->c_coll.coll_##__api##_module, "demo"); \ + } \ + } while (0) int mca_coll_demo_module_enable(mca_coll_base_module_t *module, @@ -180,23 +107,70 @@ mca_coll_demo_module_enable(mca_coll_base_module_t *module, printf("Hello! This is the \"demo\" coll component. I'll be your coll component\ntoday. Please tip your waitresses well.\n"); } - /* save the old pointers */ - COPY(comm, demo_module, allgather); - COPY(comm, demo_module, allgatherv); - COPY(comm, demo_module, allreduce); - COPY(comm, demo_module, alltoall); - COPY(comm, demo_module, alltoallv); - COPY(comm, demo_module, alltoallw); - COPY(comm, demo_module, barrier); - COPY(comm, demo_module, bcast); - COPY(comm, demo_module, exscan); - COPY(comm, demo_module, gather); - COPY(comm, demo_module, gatherv); - COPY(comm, demo_module, reduce); - COPY(comm, demo_module, reduce_scatter); - COPY(comm, demo_module, scan); - COPY(comm, demo_module, scatter); - COPY(comm, demo_module, scatterv); + if (OMPI_COMM_IS_INTRA(comm)) + { + DEMO_INSTALL_COLL_API(comm, demo_module, allgather, mca_coll_demo_allgather_intra); + DEMO_INSTALL_COLL_API(comm, demo_module, allgatherv, mca_coll_demo_allgatherv_intra); + DEMO_INSTALL_COLL_API(comm, demo_module, allreduce, mca_coll_demo_allreduce_intra); + DEMO_INSTALL_COLL_API(comm, demo_module, alltoall, mca_coll_demo_alltoall_intra); + DEMO_INSTALL_COLL_API(comm, demo_module, alltoallv, mca_coll_demo_alltoallv_intra); + DEMO_INSTALL_COLL_API(comm, demo_module, alltoallw, mca_coll_demo_alltoallw_intra); + DEMO_INSTALL_COLL_API(comm, demo_module, barrier, mca_coll_demo_barrier_intra); + DEMO_INSTALL_COLL_API(comm, demo_module, bcast, mca_coll_demo_bcast_intra); + DEMO_INSTALL_COLL_API(comm, demo_module, exscan, mca_coll_demo_exscan_intra); + DEMO_INSTALL_COLL_API(comm, demo_module, gather, mca_coll_demo_gather_intra); + DEMO_INSTALL_COLL_API(comm, demo_module, gatherv, mca_coll_demo_gatherv_intra); + DEMO_INSTALL_COLL_API(comm, demo_module, reduce, mca_coll_demo_reduce_intra); + DEMO_INSTALL_COLL_API(comm, demo_module, reduce_scatter, mca_coll_demo_reduce_scatter_intra); + DEMO_INSTALL_COLL_API(comm, demo_module, scan, mca_coll_demo_scan_intra); + DEMO_INSTALL_COLL_API(comm, demo_module, scatter, mca_coll_demo_scatter_intra); + DEMO_INSTALL_COLL_API(comm, demo_module, scatterv, mca_coll_demo_scatterv_intra); + } + else + { + DEMO_INSTALL_COLL_API(comm, demo_module, allgather, mca_coll_demo_allgather_inter); + DEMO_INSTALL_COLL_API(comm, demo_module, allgatherv, mca_coll_demo_allgatherv_inter); + DEMO_INSTALL_COLL_API(comm, demo_module, allreduce, mca_coll_demo_allreduce_inter); + DEMO_INSTALL_COLL_API(comm, demo_module, alltoall, mca_coll_demo_alltoall_inter); + DEMO_INSTALL_COLL_API(comm, demo_module, alltoallv, mca_coll_demo_alltoallv_inter); + DEMO_INSTALL_COLL_API(comm, demo_module, alltoallw, mca_coll_demo_alltoallw_inter); + DEMO_INSTALL_COLL_API(comm, demo_module, barrier, mca_coll_demo_barrier_inter); + DEMO_INSTALL_COLL_API(comm, demo_module, bcast, mca_coll_demo_bcast_inter); + /* Skip the exscan for inter-comms, it is not supported */ + DEMO_INSTALL_COLL_API(comm, demo_module, gather, mca_coll_demo_gather_inter); + DEMO_INSTALL_COLL_API(comm, demo_module, gatherv, mca_coll_demo_gatherv_inter); + DEMO_INSTALL_COLL_API(comm, demo_module, reduce, mca_coll_demo_reduce_inter); + DEMO_INSTALL_COLL_API(comm, demo_module, reduce_scatter, 
mca_coll_demo_reduce_scatter_inter); + /* Skip the scan for inter-comms, it is not supported */ + DEMO_INSTALL_COLL_API(comm, demo_module, scatter, mca_coll_demo_scatter_inter); + DEMO_INSTALL_COLL_API(comm, demo_module, scatterv, mca_coll_demo_scatterv_inter); + } + return OMPI_SUCCESS; +} + +static int +mca_coll_demo_module_disable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm) +{ + mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; + + /* put back the old pointers */ + DEMO_UNINSTALL_COLL_API(comm, demo_module, allgather); + DEMO_UNINSTALL_COLL_API(comm, demo_module, allgatherv); + DEMO_UNINSTALL_COLL_API(comm, demo_module, allreduce); + DEMO_UNINSTALL_COLL_API(comm, demo_module, alltoall); + DEMO_UNINSTALL_COLL_API(comm, demo_module, alltoallv); + DEMO_UNINSTALL_COLL_API(comm, demo_module, alltoallw); + DEMO_UNINSTALL_COLL_API(comm, demo_module, barrier); + DEMO_UNINSTALL_COLL_API(comm, demo_module, bcast); + DEMO_UNINSTALL_COLL_API(comm, demo_module, exscan); + DEMO_UNINSTALL_COLL_API(comm, demo_module, gather); + DEMO_UNINSTALL_COLL_API(comm, demo_module, gatherv); + DEMO_UNINSTALL_COLL_API(comm, demo_module, reduce); + DEMO_UNINSTALL_COLL_API(comm, demo_module, reduce_scatter); + DEMO_UNINSTALL_COLL_API(comm, demo_module, scan); + DEMO_UNINSTALL_COLL_API(comm, demo_module, scatter); + DEMO_UNINSTALL_COLL_API(comm, demo_module, scatterv); return OMPI_SUCCESS; } diff --git a/ompi/mca/coll/demo/coll_demo_reduce.c b/ompi/mca/coll/demo/coll_demo_reduce.c index 6df413902b6..03875148507 100644 --- a/ompi/mca/coll/demo/coll_demo_reduce.c +++ b/ompi/mca/coll/demo/coll_demo_reduce.c @@ -9,6 +9,7 @@ * University of Stuttgart. All rights reserved. * Copyright (c) 2004-2005 The Regents of the University of California. * All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. 
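Condensing the demo hunks above into one place: on enable, DEMO_INSTALL_COLL_API first stashes whichever implementation the communicator currently points at (via MCA_COLL_SAVE_API) into demo_module->c_coll and then installs the demo wrapper; the wrappers (coll_demo_bcast.c and friends) simply forward to the stashed pair; on disable, DEMO_UNINSTALL_COLL_API puts the stashed pair back, but only if the demo wrapper is still the installed one. A minimal sketch for a single collective, with the missing-collective show_help check and the verbose traces omitted (illustrative, not patch content):

    /* enable path: save the current selection, then take over (sketch) */
    static int demo_take_over_bcast(mca_coll_demo_module_t *demo_module,
                                    struct ompi_communicator_t *comm)
    {
        demo_module->c_coll.coll_bcast        = comm->c_coll->coll_bcast;
        demo_module->c_coll.coll_bcast_module = comm->c_coll->coll_bcast_module;
        comm->c_coll->coll_bcast              = mca_coll_demo_bcast_intra;
        comm->c_coll->coll_bcast_module       = &demo_module->super;
        return OMPI_SUCCESS;
    }

    /* disable path: restore the saved pair only if the demo wrapper is still installed (sketch) */
    static int demo_hand_back_bcast(mca_coll_demo_module_t *demo_module,
                                    struct ompi_communicator_t *comm)
    {
        if (&demo_module->super == comm->c_coll->coll_bcast_module) {
            comm->c_coll->coll_bcast        = demo_module->c_coll.coll_bcast;
            comm->c_coll->coll_bcast_module = demo_module->c_coll.coll_bcast_module;
        }
        return OMPI_SUCCESS;
    }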
* $COPYRIGHT$ * * Additional copyrights may follow @@ -33,7 +34,7 @@ * Accepts: - same as MPI_Reduce() * Returns: - MPI_SUCCESS or error code */ -int mca_coll_demo_reduce_intra(void *sbuf, void *rbuf, int count, +int mca_coll_demo_reduce_intra(const void *sbuf, void *rbuf, int count, struct ompi_datatype_t *dtype, struct ompi_op_t *op, int root, struct ompi_communicator_t *comm, @@ -41,9 +42,9 @@ int mca_coll_demo_reduce_intra(void *sbuf, void *rbuf, int count, { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo reduce_intra"); - return demo_module->underlying.coll_reduce(sbuf, rbuf, count, dtype, - op, root, comm, - demo_module->underlying.coll_reduce_module); + return demo_module->c_coll.coll_reduce(sbuf, rbuf, count, dtype, + op, root, comm, + demo_module->c_coll.coll_reduce_module); } @@ -54,7 +55,7 @@ int mca_coll_demo_reduce_intra(void *sbuf, void *rbuf, int count, * Accepts: - same as MPI_Reduce() * Returns: - MPI_SUCCESS or error code */ -int mca_coll_demo_reduce_inter(void *sbuf, void *rbuf, int count, +int mca_coll_demo_reduce_inter(const void *sbuf, void *rbuf, int count, struct ompi_datatype_t *dtype, struct ompi_op_t *op, int root, struct ompi_communicator_t *comm, @@ -62,7 +63,7 @@ int mca_coll_demo_reduce_inter(void *sbuf, void *rbuf, int count, { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo reduce_inter"); - return demo_module->underlying.coll_reduce(sbuf, rbuf, count, dtype, - op, root, comm, - demo_module->underlying.coll_reduce_module); + return demo_module->c_coll.coll_reduce(sbuf, rbuf, count, dtype, + op, root, comm, + demo_module->c_coll.coll_reduce_module); } diff --git a/ompi/mca/coll/demo/coll_demo_reduce_scatter.c b/ompi/mca/coll/demo/coll_demo_reduce_scatter.c index 438f1008b3a..e77babd73c1 100644 --- a/ompi/mca/coll/demo/coll_demo_reduce_scatter.c +++ b/ompi/mca/coll/demo/coll_demo_reduce_scatter.c @@ -9,6 +9,7 @@ * University of Stuttgart. All rights reserved. * Copyright (c) 2004-2005 The Regents of the University of California. * All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. 
* $COPYRIGHT$ * * Additional copyrights may follow @@ -33,7 +34,7 @@ * Accepts: - same as MPI_Reduce_scatter() * Returns: - MPI_SUCCESS or error code */ -int mca_coll_demo_reduce_scatter_intra(void *sbuf, void *rbuf, int *rcounts, +int mca_coll_demo_reduce_scatter_intra(const void *sbuf, void *rbuf, const int *rcounts, struct ompi_datatype_t *dtype, struct ompi_op_t *op, struct ompi_communicator_t *comm, @@ -41,9 +42,9 @@ int mca_coll_demo_reduce_scatter_intra(void *sbuf, void *rbuf, int *rcounts, { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo scatter_intra"); - return demo_module->underlying.coll_reduce_scatter(sbuf, rbuf, rcounts, - dtype, op, comm, - demo_module->underlying.coll_reduce_scatter_module); + return demo_module->c_coll.coll_reduce_scatter(sbuf, rbuf, rcounts, + dtype, op, comm, + demo_module->c_coll.coll_reduce_scatter_module); } @@ -54,7 +55,7 @@ int mca_coll_demo_reduce_scatter_intra(void *sbuf, void *rbuf, int *rcounts, * Accepts: - same arguments as MPI_Reduce_scatter() * Returns: - MPI_SUCCESS or error code */ -int mca_coll_demo_reduce_scatter_inter(void *sbuf, void *rbuf, int *rcounts, +int mca_coll_demo_reduce_scatter_inter(const void *sbuf, void *rbuf, const int *rcounts, struct ompi_datatype_t *dtype, struct ompi_op_t *op, struct ompi_communicator_t *comm, @@ -62,7 +63,7 @@ int mca_coll_demo_reduce_scatter_inter(void *sbuf, void *rbuf, int *rcounts, { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo scatter_inter"); - return demo_module->underlying.coll_reduce_scatter(sbuf, rbuf, rcounts, - dtype, op, comm, - demo_module->underlying.coll_reduce_scatter_module); + return demo_module->c_coll.coll_reduce_scatter(sbuf, rbuf, rcounts, + dtype, op, comm, + demo_module->c_coll.coll_reduce_scatter_module); } diff --git a/ompi/mca/coll/demo/coll_demo_scan.c b/ompi/mca/coll/demo/coll_demo_scan.c index 90d3cb343b1..614d30e4ef8 100644 --- a/ompi/mca/coll/demo/coll_demo_scan.c +++ b/ompi/mca/coll/demo/coll_demo_scan.c @@ -9,6 +9,7 @@ * University of Stuttgart. All rights reserved. * Copyright (c) 2004-2005 The Regents of the University of California. * All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -33,7 +34,7 @@ * Accepts: - same arguments as MPI_Scan() * Returns: - MPI_SUCCESS or error code */ -int mca_coll_demo_scan_intra(void *sbuf, void *rbuf, int count, +int mca_coll_demo_scan_intra(const void *sbuf, void *rbuf, int count, struct ompi_datatype_t *dtype, struct ompi_op_t *op, struct ompi_communicator_t *comm, @@ -41,9 +42,9 @@ int mca_coll_demo_scan_intra(void *sbuf, void *rbuf, int count, { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo scan_intra"); - return demo_module->underlying.coll_scan(sbuf, rbuf, count, - dtype, op, comm, - demo_module->underlying.coll_scan_module); + return demo_module->c_coll.coll_scan(sbuf, rbuf, count, + dtype, op, comm, + demo_module->c_coll.coll_scan_module); } diff --git a/ompi/mca/coll/demo/coll_demo_scatter.c b/ompi/mca/coll/demo/coll_demo_scatter.c index ccc2e401df6..f012e0dfb5f 100644 --- a/ompi/mca/coll/demo/coll_demo_scatter.c +++ b/ompi/mca/coll/demo/coll_demo_scatter.c @@ -9,6 +9,7 @@ * University of Stuttgart. All rights reserved. 
* Copyright (c) 2004-2005 The Regents of the University of California. * All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -33,7 +34,7 @@ * Accepts: - same arguments as MPI_Scatter() * Returns: - MPI_SUCCESS or error code */ -int mca_coll_demo_scatter_intra(void *sbuf, int scount, +int mca_coll_demo_scatter_intra(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void *rbuf, int rcount, struct ompi_datatype_t *rdtype, @@ -43,10 +44,10 @@ int mca_coll_demo_scatter_intra(void *sbuf, int scount, { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo scatter_intra"); - return demo_module->underlying.coll_scatter(sbuf, scount, sdtype, - rbuf, rcount, rdtype, - root, comm, - demo_module->underlying.coll_scatter_module); + return demo_module->c_coll.coll_scatter(sbuf, scount, sdtype, + rbuf, rcount, rdtype, + root, comm, + demo_module->c_coll.coll_scatter_module); } @@ -57,7 +58,7 @@ int mca_coll_demo_scatter_intra(void *sbuf, int scount, * Accepts: - same arguments as MPI_Scatter() * Returns: - MPI_SUCCESS or error code */ -int mca_coll_demo_scatter_inter(void *sbuf, int scount, +int mca_coll_demo_scatter_inter(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void *rbuf, int rcount, struct ompi_datatype_t *rdtype, @@ -67,8 +68,8 @@ int mca_coll_demo_scatter_inter(void *sbuf, int scount, { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo scatter_inter"); - return demo_module->underlying.coll_scatter(sbuf, scount, sdtype, - rbuf, rcount, rdtype, - root, comm, - demo_module->underlying.coll_scatter_module); + return demo_module->c_coll.coll_scatter(sbuf, scount, sdtype, + rbuf, rcount, rdtype, + root, comm, + demo_module->c_coll.coll_scatter_module); } diff --git a/ompi/mca/coll/demo/coll_demo_scatterv.c b/ompi/mca/coll/demo/coll_demo_scatterv.c index 3084efc0de5..8e2101a8f36 100644 --- a/ompi/mca/coll/demo/coll_demo_scatterv.c +++ b/ompi/mca/coll/demo/coll_demo_scatterv.c @@ -9,6 +9,7 @@ * University of Stuttgart. All rights reserved. * Copyright (c) 2004-2005 The Regents of the University of California. * All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. 
* $COPYRIGHT$ * * Additional copyrights may follow @@ -33,8 +34,8 @@ * Accepts: - same arguments as MPI_Scatterv() * Returns: - MPI_SUCCESS or error code */ -int mca_coll_demo_scatterv_intra(void *sbuf, int *scounts, - int *disps, struct ompi_datatype_t *sdtype, +int mca_coll_demo_scatterv_intra(const void *sbuf, const int *scounts, + const int *disps, struct ompi_datatype_t *sdtype, void *rbuf, int rcount, struct ompi_datatype_t *rdtype, int root, struct ompi_communicator_t *comm, @@ -42,10 +43,10 @@ int mca_coll_demo_scatterv_intra(void *sbuf, int *scounts, { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo scatterv_intra"); - return demo_module->underlying.coll_scatterv(sbuf, scounts, disps, - sdtype, rbuf, rcount, - rdtype, root, comm, - demo_module->underlying.coll_scatterv_module); + return demo_module->c_coll.coll_scatterv(sbuf, scounts, disps, + sdtype, rbuf, rcount, + rdtype, root, comm, + demo_module->c_coll.coll_scatterv_module); } @@ -56,8 +57,8 @@ int mca_coll_demo_scatterv_intra(void *sbuf, int *scounts, * Accepts: - same arguments as MPI_Scatterv() * Returns: - MPI_SUCCESS or error code */ -int mca_coll_demo_scatterv_inter(void *sbuf, int *scounts, - int *disps, struct ompi_datatype_t *sdtype, +int mca_coll_demo_scatterv_inter(const void *sbuf, const int *scounts, + const int *disps, struct ompi_datatype_t *sdtype, void *rbuf, int rcount, struct ompi_datatype_t *rdtype, int root, struct ompi_communicator_t *comm, @@ -65,8 +66,8 @@ int mca_coll_demo_scatterv_inter(void *sbuf, int *scounts, { mca_coll_demo_module_t *demo_module = (mca_coll_demo_module_t*) module; opal_output_verbose(10, ompi_coll_base_framework.framework_output, "In demo scatterv_inter"); - return demo_module->underlying.coll_scatterv(sbuf, scounts, disps, - sdtype, rbuf, rcount, - rdtype, root, comm, - demo_module->underlying.coll_scatterv_module); + return demo_module->c_coll.coll_scatterv(sbuf, scounts, disps, + sdtype, rbuf, rcount, + rdtype, root, comm, + demo_module->c_coll.coll_scatterv_module); } diff --git a/ompi/mca/coll/ftagree/coll_ftagree.h b/ompi/mca/coll/ftagree/coll_ftagree.h index 86e0d4da314..9ca5b8b9b1f 100644 --- a/ompi/mca/coll/ftagree/coll_ftagree.h +++ b/ompi/mca/coll/ftagree/coll_ftagree.h @@ -3,6 +3,7 @@ * Copyright (c) 2012-2020 The University of Tennessee and The University * of Tennessee Research Foundation. All rights * reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -98,9 +99,6 @@ mca_coll_base_module_t *mca_coll_ftagree_comm_query(struct ompi_communicator_t *comm, int *priority); -int mca_coll_ftagree_module_enable(mca_coll_base_module_t *module, - struct ompi_communicator_t *comm); - /* * Agreement algorithms */ diff --git a/ompi/mca/coll/ftagree/coll_ftagree_module.c b/ompi/mca/coll/ftagree/coll_ftagree_module.c index 6947139a7f5..755fca1b9bc 100644 --- a/ompi/mca/coll/ftagree/coll_ftagree_module.c +++ b/ompi/mca/coll/ftagree/coll_ftagree_module.c @@ -2,6 +2,7 @@ * Copyright (c) 2012-2020 The University of Tennessee and The University * of Tennessee Research Foundation. All rights * reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. 
* $COPYRIGHT$ * * Additional copyrights may follow @@ -33,6 +34,30 @@ mca_coll_ftagree_init_query(bool enable_progress_threads, return OMPI_SUCCESS; } +/* + * Init/Fini module on the communicator + */ +static int +mca_coll_ftagree_module_enable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm) +{ + MCA_COLL_INSTALL_API(comm, agree, module->coll_agree, module, "ftagree"); + MCA_COLL_INSTALL_API(comm, iagree, module->coll_iagree, module, "ftagree"); + + /* All done */ + return OMPI_SUCCESS; +} + +static int +mca_coll_ftagree_module_disable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm) +{ + MCA_COLL_INSTALL_API(comm, agree, NULL, NULL, "ftagree"); + MCA_COLL_INSTALL_API(comm, iagree, NULL, NULL, "ftagree"); + + /* All done */ + return OMPI_SUCCESS; +} /* * Invoked when there's a new communicator that has been created. @@ -76,6 +101,7 @@ mca_coll_ftagree_comm_query(struct ompi_communicator_t *comm, * algorithms. */ ftagree_module->super.coll_module_enable = mca_coll_ftagree_module_enable; + ftagree_module->super.coll_module_disable = mca_coll_ftagree_module_disable; /* This component does not provide any base collectives, * just the FT collectives. @@ -112,15 +138,3 @@ mca_coll_ftagree_comm_query(struct ompi_communicator_t *comm, return &(ftagree_module->super); } - -/* - * Init module on the communicator - */ -int -mca_coll_ftagree_module_enable(mca_coll_base_module_t *module, - struct ompi_communicator_t *comm) -{ - /* All done */ - return OMPI_SUCCESS; -} - diff --git a/ompi/mca/coll/han/coll_han.h b/ompi/mca/coll/han/coll_han.h index 29047abb317..4e5323fc046 100644 --- a/ompi/mca/coll/han/coll_han.h +++ b/ompi/mca/coll/han/coll_han.h @@ -6,6 +6,9 @@ * Copyright (c) 2020-2022 Bull S.A.S. All rights reserved. * Copyright (c) Amazon.com, Inc. or its affiliates. * All rights reserved. + * Copyright (c) 2023 Computer Architecture and VLSI Systems (CARV) + * Laboratory, ICS Forth. All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -46,11 +49,11 @@ /* * Today; - * . only 2 modules available for intranode (low) level + * . 3 modules available for intranode (low) level * . only 2 modules available for internode (up) level */ -#define COLL_HAN_LOW_MODULES 2 +#define COLL_HAN_LOW_MODULES 3 #define COLL_HAN_UP_MODULES 2 struct mca_coll_han_bcast_args_s { @@ -276,13 +279,14 @@ typedef struct mca_coll_han_component_t { int max_dynamic_errors; } mca_coll_han_component_t; - /* * Structure used to store what is necessary for the collective operations * routines in case of fallback. */ -typedef struct mca_coll_han_single_collective_fallback_s { - union { +typedef struct mca_coll_han_single_collective_fallback_s +{ + union + { mca_coll_base_module_allgather_fn_t allgather; mca_coll_base_module_allgatherv_fn_t allgatherv; mca_coll_base_module_allreduce_fn_t allreduce; @@ -293,7 +297,7 @@ typedef struct mca_coll_han_single_collective_fallback_s { mca_coll_base_module_reduce_fn_t reduce; mca_coll_base_module_scatter_fn_t scatter; mca_coll_base_module_scatterv_fn_t scatterv; - } module_fn; + }; mca_coll_base_module_t* module; } mca_coll_han_single_collective_fallback_t; @@ -302,7 +306,8 @@ typedef struct mca_coll_han_single_collective_fallback_s { * by HAN. This structure is used as a fallback during subcommunicator * creation. 
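The ftagree change above replaces the exported no-op enable with a static enable/disable pair that publishes and withdraws the agree/iagree functions through MCA_COLL_INSTALL_API. A compact model of that lifecycle; install_api() and the stub types below are illustrative stand-ins for the real macro and OMPI types, not their actual definitions:

#include <stdio.h>
#include <stddef.h>

typedef int (*agree_fn_t)(void);

/* placeholder for the per-communicator collective table */
typedef struct {
    agree_fn_t coll_agree;
    void      *coll_agree_module;
} stub_comm_t;

/* placeholder for the component's module */
typedef struct {
    agree_fn_t coll_agree;
} stub_module_t;

/* stand-in for the MCA_COLL_INSTALL_API macro: point the communicator
 * at the given function/module pair (or clear it when fn is NULL) */
static void install_api(stub_comm_t *comm, agree_fn_t fn, void *module,
                        const char *who)
{
    printf("%s: %s agree\n", who, fn ? "installing" : "clearing");
    comm->coll_agree = fn;
    comm->coll_agree_module = module;
}

static int my_agree(void) { return 0; }

/* enable publishes the module's functions on the communicator... */
static int module_enable(stub_module_t *m, stub_comm_t *comm)
{
    install_api(comm, m->coll_agree, m, "ftagree-like enable");
    return 0;
}

/* ...and disable withdraws them again, mirroring the diff */
static int module_disable(stub_module_t *m, stub_comm_t *comm)
{
    (void)m;
    install_api(comm, NULL, NULL, "ftagree-like disable");
    return 0;
}

int main(void)
{
    stub_module_t m = { my_agree };
    stub_comm_t comm = { NULL, NULL };
    module_enable(&m, &comm);
    module_disable(&m, &comm);
    return 0;
}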
*/ -typedef struct mca_coll_han_collectives_fallback_s { +typedef struct mca_coll_han_collectives_fallback_s +{ mca_coll_han_single_collective_fallback_t allgather; mca_coll_han_single_collective_fallback_t allgatherv; mca_coll_han_single_collective_fallback_t allreduce; @@ -334,7 +339,7 @@ typedef struct mca_coll_han_module_t { bool is_heterogeneous; /* To be able to fallback when the cases are not supported */ - struct mca_coll_han_collectives_fallback_s fallback; + mca_coll_han_collectives_fallback_t fallback; /* To be able to fallback on reproducible algorithm */ mca_coll_base_module_reduce_fn_t reproducible_reduce; @@ -365,61 +370,61 @@ OBJ_CLASS_DECLARATION(mca_coll_han_module_t); * Some defines to stick to the naming used in the other components in terms of * fallback routines */ -#define previous_allgather fallback.allgather.module_fn.allgather +#define previous_allgather fallback.allgather.allgather #define previous_allgather_module fallback.allgather.module -#define previous_allgatherv fallback.allgatherv.module_fn.allgatherv +#define previous_allgatherv fallback.allgatherv.allgatherv #define previous_allgatherv_module fallback.allgatherv.module -#define previous_allreduce fallback.allreduce.module_fn.allreduce +#define previous_allreduce fallback.allreduce.allreduce #define previous_allreduce_module fallback.allreduce.module -#define previous_barrier fallback.barrier.module_fn.barrier +#define previous_barrier fallback.barrier.barrier #define previous_barrier_module fallback.barrier.module -#define previous_bcast fallback.bcast.module_fn.bcast +#define previous_bcast fallback.bcast.bcast #define previous_bcast_module fallback.bcast.module -#define previous_reduce fallback.reduce.module_fn.reduce +#define previous_reduce fallback.reduce.reduce #define previous_reduce_module fallback.reduce.module -#define previous_gather fallback.gather.module_fn.gather +#define previous_gather fallback.gather.gather #define previous_gather_module fallback.gather.module -#define previous_gatherv fallback.gatherv.module_fn.gatherv +#define previous_gatherv fallback.gatherv.gatherv #define previous_gatherv_module fallback.gatherv.module -#define previous_scatter fallback.scatter.module_fn.scatter +#define previous_scatter fallback.scatter.scatter #define previous_scatter_module fallback.scatter.module -#define previous_scatterv fallback.scatterv.module_fn.scatterv +#define previous_scatterv fallback.scatterv.scatterv #define previous_scatterv_module fallback.scatterv.module /* macro to correctly load a fallback collective module */ -#define HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, COLL) \ - do { \ - if ( ((COMM)->c_coll->coll_ ## COLL ## _module) == (mca_coll_base_module_t*)(HANM) ) { \ - (COMM)->c_coll->coll_ ## COLL = (HANM)->previous_## COLL; \ - mca_coll_base_module_t *coll_module = (COMM)->c_coll->coll_ ## COLL ## _module; \ - (COMM)->c_coll->coll_ ## COLL ## _module = (HANM)->previous_ ## COLL ## _module; \ - OBJ_RETAIN((COMM)->c_coll->coll_ ## COLL ## _module); \ - OBJ_RELEASE(coll_module); \ - } \ - } while(0) - -/* macro to correctly load /all/ fallback collectives */ -#define HAN_LOAD_FALLBACK_COLLECTIVES(HANM, COMM) \ - do { \ - HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, barrier); \ - HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, bcast); \ - HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, scatter); \ - HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, scatterv); \ - HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, gather); \ - HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, gatherv); \ - HAN_LOAD_FALLBACK_COLLECTIVE(HANM, 
COMM, reduce); \ - HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, allreduce); \ - HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, allgather); \ - HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, allgatherv); \ +#define HAN_UNINSTALL_COLL_API(__comm, __module, __api) \ + do \ + { \ + if (__comm->c_coll->coll_##__api##_module == &__module->super) \ + { \ + MCA_COLL_INSTALL_API(__comm, __api, __module->previous_##__api, \ + __module->previous_##__api##_module, "han"); \ + /* Do not reset the fallback to NULL it will be needed */ \ + } \ + } while (0) + + /* macro to correctly load /all/ fallback collectives */ +#define HAN_LOAD_FALLBACK_COLLECTIVES(COMM, HANM) \ + do { \ + HAN_UNINSTALL_COLL_API(COMM, HANM, barrier); \ + HAN_UNINSTALL_COLL_API(COMM, HANM, bcast); \ + HAN_UNINSTALL_COLL_API(COMM, HANM, scatter); \ + HAN_UNINSTALL_COLL_API(COMM, HANM, scatterv); \ + HAN_UNINSTALL_COLL_API(COMM, HANM, gather); \ + HAN_UNINSTALL_COLL_API(COMM, HANM, gatherv); \ + HAN_UNINSTALL_COLL_API(COMM, HANM, reduce); \ + HAN_UNINSTALL_COLL_API(COMM, HANM, allreduce); \ + HAN_UNINSTALL_COLL_API(COMM, HANM, allgather); \ + HAN_UNINSTALL_COLL_API(COMM, HANM, allgatherv); \ han_module->enabled = false; /* entire module set to pass-through from now on */ \ } while(0) diff --git a/ompi/mca/coll/han/coll_han_allgather.c b/ompi/mca/coll/han/coll_han_allgather.c index fa827dccf09..8c57aeffb5b 100644 --- a/ompi/mca/coll/han/coll_han_allgather.c +++ b/ompi/mca/coll/han/coll_han_allgather.c @@ -4,6 +4,7 @@ * reserved. * Copyright (c) 2020 Bull S.A.S. All rights reserved. * Copyright (c) 2022 IBM Corporation. All rights reserved + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -82,7 +83,7 @@ mca_coll_han_allgather_intra(const void *sbuf, int scount, OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output, "han cannot handle allgather within this communicator. Fall back on another component\n")); /* HAN cannot work with this communicator so fallback on all collectives */ - HAN_LOAD_FALLBACK_COLLECTIVES(han_module, comm); + HAN_LOAD_FALLBACK_COLLECTIVES(comm, han_module); return han_module->previous_allgather(sbuf, scount, sdtype, rbuf, rcount, rdtype, comm, han_module->previous_allgather_module); } @@ -97,7 +98,7 @@ mca_coll_han_allgather_intra(const void *sbuf, int scount, if (han_module->are_ppn_imbalanced) { OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output, "han cannot handle allgather with this communicator (imbalance). Fall back on another component\n")); - HAN_LOAD_FALLBACK_COLLECTIVE(han_module, comm, allgather); + HAN_UNINSTALL_COLL_API(comm, han_module, allgather); return han_module->previous_allgather(sbuf, scount, sdtype, rbuf, rcount, rdtype, comm, han_module->previous_allgather_module); } @@ -306,7 +307,7 @@ mca_coll_han_allgather_intra_simple(const void *sbuf, int scount, OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output, "han cannot handle allgather within this communicator. Fall back on another component\n")); /* HAN cannot work with this communicator so fallback on all collectives */ - HAN_LOAD_FALLBACK_COLLECTIVES(han_module, comm); + HAN_LOAD_FALLBACK_COLLECTIVES(comm, han_module); return han_module->previous_allgather(sbuf, scount, sdtype, rbuf, rcount, rdtype, comm, han_module->previous_allgather_module); } @@ -320,7 +321,7 @@ mca_coll_han_allgather_intra_simple(const void *sbuf, int scount, /* Put back the fallback collective support and call it once. All * future calls will then be automatically redirected. 
*/ - HAN_LOAD_FALLBACK_COLLECTIVE(han_module, comm, allgather); + HAN_UNINSTALL_COLL_API(comm, han_module, allgather); return han_module->previous_allgather(sbuf, scount, sdtype, rbuf, rcount, rdtype, comm, han_module->previous_allgather_module); } diff --git a/ompi/mca/coll/han/coll_han_allreduce.c b/ompi/mca/coll/han/coll_han_allreduce.c index 039913d7fdb..81212145d41 100644 --- a/ompi/mca/coll/han/coll_han_allreduce.c +++ b/ompi/mca/coll/han/coll_han_allreduce.c @@ -6,6 +6,7 @@ * * Copyright (c) 2020 Cisco Systems, Inc. All rights reserved. * Copyright (c) 2022 IBM Corporation. All rights reserved + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -109,7 +110,7 @@ mca_coll_han_allreduce_intra(const void *sbuf, OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output, "han cannot handle allreduce with this communicator. Drop HAN support in this communicator and fall back on another component\n")); /* HAN cannot work with this communicator so fallback on all collectives */ - HAN_LOAD_FALLBACK_COLLECTIVES(han_module, comm); + HAN_LOAD_FALLBACK_COLLECTIVES(comm, han_module); return han_module->previous_allreduce(sbuf, rbuf, count, dtype, op, comm, han_module->previous_allreduce_module); } @@ -495,7 +496,7 @@ mca_coll_han_allreduce_intra_simple(const void *sbuf, OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output, "han cannot handle allreduce with this communicator. Drop HAN support in this communicator and fall back on another component\n")); /* HAN cannot work with this communicator so fallback on all collectives */ - HAN_LOAD_FALLBACK_COLLECTIVES(han_module, comm); + HAN_LOAD_FALLBACK_COLLECTIVES(comm, han_module); return han_module->previous_allreduce(sbuf, rbuf, count, dtype, op, comm, han_module->previous_allreduce_module); } diff --git a/ompi/mca/coll/han/coll_han_barrier.c b/ompi/mca/coll/han/coll_han_barrier.c index 6f437135ced..19b0e38c9a7 100644 --- a/ompi/mca/coll/han/coll_han_barrier.c +++ b/ompi/mca/coll/han/coll_han_barrier.c @@ -3,6 +3,7 @@ * of Tennessee Research Foundation. All rights * reserved. * Copyright (c) 2020 Bull S.A.S. All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -36,10 +37,8 @@ mca_coll_han_barrier_intra_simple(struct ompi_communicator_t *comm, if( OMPI_SUCCESS != mca_coll_han_comm_create_new(comm, han_module) ) { OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output, "han cannot handle barrier with this communicator. Fall back on another component\n")); - /* Put back the fallback collective support and call it once. All - * future calls will then be automatically redirected. - */ - HAN_LOAD_FALLBACK_COLLECTIVES(han_module, comm); + /* Try to put back the fallback collective support and call it once. */ + HAN_LOAD_FALLBACK_COLLECTIVES(comm, han_module); return han_module->previous_barrier(comm, han_module->previous_barrier_module); } diff --git a/ompi/mca/coll/han/coll_han_bcast.c b/ompi/mca/coll/han/coll_han_bcast.c index 33a3cc8af44..b543c10db38 100644 --- a/ompi/mca/coll/han/coll_han_bcast.c +++ b/ompi/mca/coll/han/coll_han_bcast.c @@ -4,6 +4,7 @@ * reserved. * Copyright (c) 2020 Bull S.A.S. All rights reserved. * Copyright (c) 2020 Cisco Systems, Inc. All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. 
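HAN_UNINSTALL_COLL_API above only restores the previous collective when the communicator still points at the HAN module, so a provider installed later is never overwritten, and the saved fallback fields are deliberately left in place. A self-contained sketch of that guard, with placeholder types standing in for the real ones:

#include <stdio.h>

typedef int (*bcast_fn_t)(void *buf, int count);

/* simplified stand-ins for the communicator's collective table and the
 * HAN module with its saved fallback */
typedef struct {
    bcast_fn_t coll_bcast;
    void      *coll_bcast_module;
} stub_comm_t;

typedef struct {
    bcast_fn_t previous_bcast;          /* saved when the module was enabled */
    void      *previous_bcast_module;
} stub_han_module_t;

/* only restore the previous bcast if this module is still the one
 * installed on the communicator; the saved fallback fields are kept */
static void uninstall_bcast(stub_comm_t *comm, stub_han_module_t *han)
{
    if (comm->coll_bcast_module == (void *)han) {
        comm->coll_bcast        = han->previous_bcast;
        comm->coll_bcast_module = han->previous_bcast_module;
    }
}

static int base_bcast(void *buf, int count) { (void)buf; (void)count; return 0; }
static int han_bcast(void *buf, int count)  { (void)buf; (void)count; return 0; }

int main(void)
{
    stub_han_module_t han = { base_bcast, NULL };
    stub_comm_t comm = { han_bcast, &han };   /* HAN currently installed */

    uninstall_bcast(&comm, &han);
    printf("bcast provider is now %s\n",
           comm.coll_bcast == base_bcast ? "the saved fallback" : "still HAN");
    return 0;
}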
* $COPYRIGHT$ * * Additional copyrights may follow @@ -83,7 +84,7 @@ mca_coll_han_bcast_intra(void *buf, /* Put back the fallback collective support and call it once. All * future calls will then be automatically redirected. */ - HAN_LOAD_FALLBACK_COLLECTIVES(han_module, comm); + HAN_LOAD_FALLBACK_COLLECTIVES(comm, han_module); return han_module->previous_bcast(buf, count, dtype, root, comm, han_module->previous_bcast_module); } @@ -96,7 +97,7 @@ mca_coll_han_bcast_intra(void *buf, /* Put back the fallback collective support and call it once. All * future calls will then be automatically redirected. */ - HAN_LOAD_FALLBACK_COLLECTIVE(han_module, comm, bcast); + HAN_UNINSTALL_COLL_API(comm, han_module, bcast); return han_module->previous_bcast(buf, count, dtype, root, comm, han_module->previous_bcast_module); } @@ -245,7 +246,7 @@ mca_coll_han_bcast_intra_simple(void *buf, /* Put back the fallback collective support and call it once. All * future calls will then be automatically redirected. */ - HAN_LOAD_FALLBACK_COLLECTIVES(han_module, comm); + HAN_LOAD_FALLBACK_COLLECTIVES(comm, han_module); return han_module->previous_bcast(buf, count, dtype, root, comm, han_module->previous_bcast_module); } @@ -258,7 +259,7 @@ mca_coll_han_bcast_intra_simple(void *buf, /* Put back the fallback collective support and call it once. All * future calls will then be automatically redirected. */ - HAN_LOAD_FALLBACK_COLLECTIVE(han_module, comm, bcast); + HAN_UNINSTALL_COLL_API(comm, han_module, bcast); return han_module->previous_bcast(buf, count, dtype, root, comm, han_module->previous_bcast_module); } diff --git a/ompi/mca/coll/han/coll_han_component.c b/ompi/mca/coll/han/coll_han_component.c index ed9582d5ffe..6ce8f5a06a8 100644 --- a/ompi/mca/coll/han/coll_han_component.c +++ b/ompi/mca/coll/han/coll_han_component.c @@ -4,6 +4,8 @@ * reserved. * Copyright (c) 2022 IBM Corporation. All rights reserved * Copyright (c) 2020-2022 Bull S.A.S. All rights reserved. + * Copyright (c) 2023 Computer Architecture and VLSI Systems (CARV) + * Laboratory, ICS Forth. All rights reserved. 
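The bcast paths above show the two fallback granularities HAN uses: if the sub-communicator machinery cannot be set up at all, every collective is handed back at once; if only one operation is affected (for example an imbalanced ppn), just that API is uninstalled. A rough sketch of that decision, with placeholder names for the real checks and macros:

#include <stdbool.h>
#include <stdio.h>

/* stand-ins for the two conditions tested in the diff */
static bool comm_create_ok = true;   /* sub-communicator creation succeeded */
static bool ppn_imbalanced = true;   /* processes-per-node are imbalanced  */

static void uninstall_one(const char *api) { printf("uninstall %s only\n", api); }

static void uninstall_all(void)
{
    static const char *apis[] = { "barrier", "bcast", "scatter", "scatterv",
                                  "gather", "gatherv", "reduce", "allreduce",
                                  "allgather", "allgatherv" };
    for (unsigned i = 0; i < sizeof(apis) / sizeof(apis[0]); ++i)
        uninstall_one(apis[i]);
}

static void bcast_entry_point(void)
{
    if (!comm_create_ok) {
        /* HAN cannot work on this communicator at all: hand every
         * collective back (analogue of HAN_LOAD_FALLBACK_COLLECTIVES),
         * then call the previous bcast once */
        uninstall_all();
        return;
    }
    if (ppn_imbalanced) {
        /* only this operation is affected: uninstall just bcast
         * (analogue of HAN_UNINSTALL_COLL_API) and fall back once */
        uninstall_one("bcast");
        return;
    }
    printf("run the HAN bcast algorithm\n");
}

int main(void)
{
    bcast_entry_point();
    return 0;
}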
* $COPYRIGHT$ * * Additional copyrights may follow @@ -43,7 +45,8 @@ ompi_coll_han_components ompi_coll_han_available_components[COMPONENTS_COUNT] = { TUNED, "tuned" }, { SM, "sm" }, /* this should not be used, the collective component is gone */ { ADAPT, "adapt" }, - { HAN, "han" } + { HAN, "han" }, + { XHC, "xhc" } }; /* @@ -287,7 +290,7 @@ static int han_register(void) cs->han_bcast_low_module = 0; (void) mca_coll_han_query_module_from_mca(c, "bcast_low_module", - "low level module for bcast, currently only 0 for tuned", + "low level module for bcast, 0 tuned, 2 xhc", OPAL_INFO_LVL_9, &cs->han_bcast_low_module, &cs->han_op_module_name.bcast.han_op_low_module_name); @@ -307,7 +310,7 @@ static int han_register(void) cs->han_reduce_low_module = 0; (void) mca_coll_han_query_module_from_mca(c, "reduce_low_module", - "low level module for allreduce, currently only 0 tuned", + "low level module for allreduce, 0 tuned, 2 xhc", OPAL_INFO_LVL_9, &cs->han_reduce_low_module, &cs->han_op_module_name.reduce.han_op_low_module_name); @@ -326,7 +329,7 @@ static int han_register(void) cs->han_allreduce_low_module = 0; (void) mca_coll_han_query_module_from_mca(c, "allreduce_low_module", - "low level module for allreduce, currently only 0 tuned", + "low level module for allreduce, 0 tuned, 2 xhc", OPAL_INFO_LVL_9, &cs->han_allreduce_low_module, &cs->han_op_module_name.allreduce.han_op_low_module_name); @@ -338,7 +341,7 @@ static int han_register(void) cs->han_allgather_low_module = 0; (void) mca_coll_han_query_module_from_mca(c, "allgather_low_module", - "low level module for allgather, currently only 0 tuned", + "low level module for allgather, 0 tuned, 2 xhc", OPAL_INFO_LVL_9, &cs->han_allgather_low_module, &cs->han_op_module_name.allgather.han_op_low_module_name); @@ -350,7 +353,7 @@ static int han_register(void) cs->han_gather_low_module = 0; (void) mca_coll_han_query_module_from_mca(c, "gather_low_module", - "low level module for gather, currently only 0 tuned", + "low level module for gather, 0 tuned, 2 xhc", OPAL_INFO_LVL_9, &cs->han_gather_low_module, &cs->han_op_module_name.gather.han_op_low_module_name); @@ -374,7 +377,7 @@ static int han_register(void) cs->han_scatter_low_module = 0; (void) mca_coll_han_query_module_from_mca(c, "scatter_low_module", - "low level module for scatter, currently only 0 tuned", + "low level module for scatter, 0 tuned, 2 xhc", OPAL_INFO_LVL_9, &cs->han_scatter_low_module, &cs->han_op_module_name.scatter.han_op_low_module_name); diff --git a/ompi/mca/coll/han/coll_han_dynamic.h b/ompi/mca/coll/han/coll_han_dynamic.h index 403e458391e..82114227308 100644 --- a/ompi/mca/coll/han/coll_han_dynamic.h +++ b/ompi/mca/coll/han/coll_han_dynamic.h @@ -5,6 +5,8 @@ * reserved. * Copyright (c) 2020 Bull S.A.S. All rights reserved. * Copyright (c) 2022 IBM Corporation. All rights reserved + * Copyright (c) 2023 Computer Architecture and VLSI Systems (CARV) + * Laboratory, ICS Forth. All rights reserved. * * $COPYRIGHT$ * @@ -105,6 +107,7 @@ typedef enum COMPONENTS { SM, ADAPT, HAN, + XHC, COMPONENTS_COUNT } COMPONENT_T; diff --git a/ompi/mca/coll/han/coll_han_gather.c b/ompi/mca/coll/han/coll_han_gather.c index 8aaeddb8d26..256e4e714f0 100644 --- a/ompi/mca/coll/han/coll_han_gather.c +++ b/ompi/mca/coll/han/coll_han_gather.c @@ -5,6 +5,7 @@ * Copyright (c) 2020 Bull S.A.S. All rights reserved. * Copyright (c) 2020 Cisco Systems, Inc. All rights reserved. * Copyright (c) 2022 IBM Corporation. All rights reserved + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. 
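Adding XHC above touches both the component name table and the COMPONENT_T enum, which must keep the same length and order; COMPONENTS_COUNT then sizes the table automatically. A trimmed, stand-alone illustration (the real enum has additional entries ahead of these):

#include <stdio.h>

/* trimmed version of the enum; the real COMPONENT_T in the diff has
 * more entries before these */
typedef enum { TUNED, SM, ADAPT, HAN, XHC, COMPONENTS_COUNT } component_t;

typedef struct { component_t id; const char *name; } component_name_t;

/* the name table must hold exactly COMPONENTS_COUNT entries and stay in
 * the same order as the enum, so appending XHC touches both places */
static const component_name_t available[COMPONENTS_COUNT] = {
    { TUNED, "tuned" }, { SM, "sm" }, { ADAPT, "adapt" },
    { HAN, "han" },     { XHC, "xhc" },
};

int main(void)
{
    component_t low_level = XHC;
    printf("low-level component: %s\n", available[low_level].name);
    return 0;
}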
* $COPYRIGHT$ * * Additional copyrights may follow @@ -92,7 +93,7 @@ mca_coll_han_gather_intra(const void *sbuf, int scount, OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output, "han cannot handle gather with this communicator. Fall back on another component\n")); /* HAN cannot work with this communicator so fallback on all collectives */ - HAN_LOAD_FALLBACK_COLLECTIVES(han_module, comm); + HAN_LOAD_FALLBACK_COLLECTIVES(comm, han_module); return han_module->previous_gather(sbuf, scount, sdtype, rbuf, rcount, rdtype, root, comm, han_module->previous_gather_module); } @@ -103,10 +104,8 @@ mca_coll_han_gather_intra(const void *sbuf, int scount, if (han_module->are_ppn_imbalanced) { OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output, "han cannot handle gather with this communicator (imbalance). Fall back on another component\n")); - /* Put back the fallback collective support and call it once. All - * future calls will then be automatically redirected. - */ - HAN_LOAD_FALLBACK_COLLECTIVE(han_module, comm, gather); + /* Unregister HAN gather if possible, and execute the fallback gather */ + HAN_UNINSTALL_COLL_API(comm, han_module, gather); return han_module->previous_gather(sbuf, scount, sdtype, rbuf, rcount, rdtype, root, comm, han_module->previous_gather_module); } @@ -319,7 +318,7 @@ mca_coll_han_gather_intra_simple(const void *sbuf, int scount, OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output, "han cannot handle gather with this communicator. Fall back on another component\n")); /* HAN cannot work with this communicator so fallback on all collectives */ - HAN_LOAD_FALLBACK_COLLECTIVES(han_module, comm); + HAN_LOAD_FALLBACK_COLLECTIVES(comm, han_module); return han_module->previous_gather(sbuf, scount, sdtype, rbuf, rcount, rdtype, root, comm, han_module->previous_gather_module); } @@ -330,10 +329,8 @@ mca_coll_han_gather_intra_simple(const void *sbuf, int scount, if (han_module->are_ppn_imbalanced){ OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output, "han cannot handle gather with this communicator (imbalance). Fall back on another component\n")); - /* Put back the fallback collective support and call it once. All - * future calls will then be automatically redirected. - */ - HAN_LOAD_FALLBACK_COLLECTIVE(han_module, comm, gather); + /* Unregister HAN gather if possible, and execute the fallback gather */ + HAN_UNINSTALL_COLL_API(comm, han_module, gather); return han_module->previous_gather(sbuf, scount, sdtype, rbuf, rcount, rdtype, root, comm, han_module->previous_gather_module); } diff --git a/ompi/mca/coll/han/coll_han_gatherv.c b/ompi/mca/coll/han/coll_han_gatherv.c index 006ab167eba..a80d93d0e92 100644 --- a/ompi/mca/coll/han/coll_han_gatherv.c +++ b/ompi/mca/coll/han/coll_han_gatherv.c @@ -7,6 +7,7 @@ * Copyright (c) 2022 IBM Corporation. All rights reserved * Copyright (c) Amazon.com, Inc. or its affiliates. * All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -76,7 +77,7 @@ int mca_coll_han_gatherv_intra(const void *sbuf, int scount, struct ompi_datatyp (30, mca_coll_han_component.han_output, "han cannot handle gatherv with this communicator. 
Fall back on another component\n")); /* HAN cannot work with this communicator so fallback on all collectives */ - HAN_LOAD_FALLBACK_COLLECTIVES(han_module, comm); + HAN_LOAD_FALLBACK_COLLECTIVES(comm, han_module); return han_module->previous_gatherv(sbuf, scount, sdtype, rbuf, rcounts, displs, rdtype, root, comm, han_module->previous_gatherv_module); } @@ -91,7 +92,7 @@ int mca_coll_han_gatherv_intra(const void *sbuf, int scount, struct ompi_datatyp /* Put back the fallback collective support and call it once. All * future calls will then be automatically redirected. */ - HAN_LOAD_FALLBACK_COLLECTIVE(han_module, comm, gatherv); + HAN_UNINSTALL_COLL_API(comm, han_module, gatherv); return han_module->previous_gatherv(sbuf, scount, sdtype, rbuf, rcounts, displs, rdtype, root, comm, han_module->previous_gatherv_module); } diff --git a/ompi/mca/coll/han/coll_han_module.c b/ompi/mca/coll/han/coll_han_module.c index 31ee2d3fb84..8be8d5a9319 100644 --- a/ompi/mca/coll/han/coll_han_module.c +++ b/ompi/mca/coll/han/coll_han_module.c @@ -7,6 +7,7 @@ * Copyright (c) 2021 Triad National Security, LLC. All rights * reserved. * Copyright (c) 2022 IBM Corporation. All rights reserved + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -29,15 +30,16 @@ /* * Local functions */ -static int han_module_enable(mca_coll_base_module_t * module, - struct ompi_communicator_t *comm); +static int mca_coll_han_module_enable(mca_coll_base_module_t * module, + struct ompi_communicator_t *comm); static int mca_coll_han_module_disable(mca_coll_base_module_t * module, struct ompi_communicator_t *comm); -#define CLEAN_PREV_COLL(HANDLE, NAME) \ - do { \ - (HANDLE)->fallback.NAME.module_fn.NAME = NULL; \ - (HANDLE)->fallback.NAME.module = NULL; \ +#define CLEAN_PREV_COLL(__module, __api) \ + do \ + { \ + (__module)->previous_##__api = NULL; \ + (__module)->previous_##__api##_module = NULL; \ } while (0) /* @@ -71,7 +73,6 @@ static void mca_coll_han_module_construct(mca_coll_han_module_t * module) module->enabled = true; module->recursive_free_depth = 0; - module->super.coll_module_disable = mca_coll_han_module_disable; module->cached_low_comms = NULL; module->cached_up_comms = NULL; module->cached_vranks = NULL; @@ -90,14 +91,6 @@ static void mca_coll_han_module_construct(mca_coll_han_module_t * module) han_module_clear(module); } - -#define OBJ_RELEASE_IF_NOT_NULL(obj) \ - do { \ - if (NULL != (obj)) { \ - OBJ_RELEASE(obj); \ - } \ - } while (0) - /* * Module destructor */ @@ -146,15 +139,6 @@ mca_coll_han_module_destruct(mca_coll_han_module_t * module) } } - OBJ_RELEASE_IF_NOT_NULL(module->previous_allgather_module); - OBJ_RELEASE_IF_NOT_NULL(module->previous_allreduce_module); - OBJ_RELEASE_IF_NOT_NULL(module->previous_bcast_module); - OBJ_RELEASE_IF_NOT_NULL(module->previous_gather_module); - OBJ_RELEASE_IF_NOT_NULL(module->previous_gatherv_module); - OBJ_RELEASE_IF_NOT_NULL(module->previous_reduce_module); - OBJ_RELEASE_IF_NOT_NULL(module->previous_scatter_module); - OBJ_RELEASE_IF_NOT_NULL(module->previous_scatterv_module); - han_module_clear(module); } @@ -214,13 +198,7 @@ mca_coll_han_comm_query(struct ompi_communicator_t * comm, int *priority) return NULL; } - han_module = OBJ_NEW(mca_coll_han_module_t); - if (NULL == han_module) { - return NULL; - } - - /* All is good -- return a module */ - han_module->topologic_level = GLOBAL_COMMUNICATOR; + int topologic_level = GLOBAL_COMMUNICATOR; if (NULL != comm->super.s_info) { /* Get the info value 
disaqualifying coll components */ @@ -229,27 +207,31 @@ mca_coll_han_comm_query(struct ompi_communicator_t * comm, int *priority) &info_str, &flag); if (flag) { - if (0 == strcmp(info_str->string, "INTER_NODE")) { - han_module->topologic_level = INTER_NODE; - } else { - han_module->topologic_level = INTRA_NODE; - } + topologic_level = strcmp(info_str->string, "INTER_NODE") ? INTRA_NODE : INTER_NODE; OBJ_RELEASE(info_str); } } if( !ompi_group_have_remote_peers(comm->c_local_group) - && INTRA_NODE != han_module->topologic_level ) { + && INTRA_NODE != topologic_level ) { /* The group only contains local processes, and this is not a * intra-node subcomm we created. Disable HAN for now */ opal_output_verbose(10, ompi_coll_base_framework.framework_output, "coll:han:comm_query (%s/%s): comm has only local processes; disqualifying myself", ompi_comm_print_cid(comm), comm->c_name); - OBJ_RELEASE(han_module); return NULL; } - han_module->super.coll_module_enable = han_module_enable; + /* All is good -- return a module */ + han_module = OBJ_NEW(mca_coll_han_module_t); + if (NULL == han_module) { + return NULL; + } + han_module->topologic_level = topologic_level; + + han_module->super.coll_module_enable = mca_coll_han_module_enable; + han_module->super.coll_module_disable = mca_coll_han_module_disable; + han_module->super.coll_alltoall = NULL; han_module->super.coll_alltoallv = NULL; han_module->super.coll_alltoallw = NULL; @@ -265,7 +247,6 @@ mca_coll_han_comm_query(struct ompi_communicator_t * comm, int *priority) han_module->super.coll_bcast = mca_coll_han_bcast_intra_dynamic; han_module->super.coll_allreduce = mca_coll_han_allreduce_intra_dynamic; han_module->super.coll_allgather = mca_coll_han_allgather_intra_dynamic; - if (GLOBAL_COMMUNICATOR == han_module->topologic_level) { /* We are on the global communicator, return topological algorithms */ han_module->super.coll_allgatherv = NULL; @@ -287,57 +268,48 @@ mca_coll_han_comm_query(struct ompi_communicator_t * comm, int *priority) * . ompi_communicator_t *comm * . mca_coll_han_module_t *han_module */ -#define HAN_SAVE_PREV_COLL_API(__api) \ +#define HAN_INSTALL_COLL_API(__comm, __module, __api) \ do { \ - if (!comm->c_coll->coll_ ## __api || !comm->c_coll->coll_ ## __api ## _module) { \ - opal_output_verbose(1, ompi_coll_base_framework.framework_output, \ - "(%s/%s): no underlying " # __api"; disqualifying myself", \ - ompi_comm_print_cid(comm), comm->c_name); \ - goto handle_error; \ - } \ - han_module->previous_ ## __api = comm->c_coll->coll_ ## __api; \ - han_module->previous_ ## __api ## _module = comm->c_coll->coll_ ## __api ## _module; \ - OBJ_RETAIN(han_module->previous_ ## __api ## _module); \ - } while(0) + if( NULL != __module->super.coll_ ## __api ) { \ + if (!__comm->c_coll->coll_##__api || !__comm->c_coll->coll_##__api##_module) { \ + opal_output_verbose(10, ompi_coll_base_framework.framework_output, \ + "(%d/%s): no underlying " #__api " ; disqualifying myself", \ + __comm->c_index, __comm->c_name); \ + } else { \ + MCA_COLL_SAVE_API(__comm, __api, __module->previous_##__api, \ + __module->previous_##__api##_module, "han"); \ + MCA_COLL_INSTALL_API(__comm, __api, __module->super.coll_##__api, &__module->super, "han"); \ + } \ + } \ + } while (0) + +/* The HAN_UNINSTALL_COLL_API is in coll_han.h header as it is needed in several places. 
*/ /* * Init module on the communicator */ static int -han_module_enable(mca_coll_base_module_t * module, - struct ompi_communicator_t *comm) +mca_coll_han_module_enable(mca_coll_base_module_t * module, + struct ompi_communicator_t *comm) { mca_coll_han_module_t * han_module = (mca_coll_han_module_t*) module; - HAN_SAVE_PREV_COLL_API(allgather); - HAN_SAVE_PREV_COLL_API(allgatherv); - HAN_SAVE_PREV_COLL_API(allreduce); - HAN_SAVE_PREV_COLL_API(barrier); - HAN_SAVE_PREV_COLL_API(bcast); - HAN_SAVE_PREV_COLL_API(gather); - HAN_SAVE_PREV_COLL_API(gatherv); - HAN_SAVE_PREV_COLL_API(reduce); - HAN_SAVE_PREV_COLL_API(scatter); - HAN_SAVE_PREV_COLL_API(scatterv); + HAN_INSTALL_COLL_API(comm, han_module, allgather); + HAN_INSTALL_COLL_API(comm, han_module, allgatherv); + HAN_INSTALL_COLL_API(comm, han_module, allreduce); + HAN_INSTALL_COLL_API(comm, han_module, barrier); + HAN_INSTALL_COLL_API(comm, han_module, bcast); + HAN_INSTALL_COLL_API(comm, han_module, gather); + HAN_INSTALL_COLL_API(comm, han_module, gatherv); + HAN_INSTALL_COLL_API(comm, han_module, reduce); + HAN_INSTALL_COLL_API(comm, han_module, scatter); + HAN_INSTALL_COLL_API(comm, han_module, scatterv); /* set reproducible algos */ mca_coll_han_reduce_reproducible_decision(comm, module); mca_coll_han_allreduce_reproducible_decision(comm, module); return OMPI_SUCCESS; - -handle_error: - OBJ_RELEASE_IF_NOT_NULL(han_module->previous_allgather_module); - OBJ_RELEASE_IF_NOT_NULL(han_module->previous_allgatherv_module); - OBJ_RELEASE_IF_NOT_NULL(han_module->previous_allreduce_module); - OBJ_RELEASE_IF_NOT_NULL(han_module->previous_bcast_module); - OBJ_RELEASE_IF_NOT_NULL(han_module->previous_gather_module); - OBJ_RELEASE_IF_NOT_NULL(han_module->previous_gatherv_module); - OBJ_RELEASE_IF_NOT_NULL(han_module->previous_reduce_module); - OBJ_RELEASE_IF_NOT_NULL(han_module->previous_scatter_module); - OBJ_RELEASE_IF_NOT_NULL(han_module->previous_scatterv_module); - - return OMPI_ERROR; } /* @@ -349,16 +321,16 @@ mca_coll_han_module_disable(mca_coll_base_module_t * module, { mca_coll_han_module_t * han_module = (mca_coll_han_module_t *) module; - OBJ_RELEASE_IF_NOT_NULL(han_module->previous_allgather_module); - OBJ_RELEASE_IF_NOT_NULL(han_module->previous_allgatherv_module); - OBJ_RELEASE_IF_NOT_NULL(han_module->previous_allreduce_module); - OBJ_RELEASE_IF_NOT_NULL(han_module->previous_barrier_module); - OBJ_RELEASE_IF_NOT_NULL(han_module->previous_bcast_module); - OBJ_RELEASE_IF_NOT_NULL(han_module->previous_gather_module); - OBJ_RELEASE_IF_NOT_NULL(han_module->previous_gatherv_module); - OBJ_RELEASE_IF_NOT_NULL(han_module->previous_reduce_module); - OBJ_RELEASE_IF_NOT_NULL(han_module->previous_scatter_module); - OBJ_RELEASE_IF_NOT_NULL(han_module->previous_scatterv_module); + HAN_UNINSTALL_COLL_API(comm, han_module, allgather); + HAN_UNINSTALL_COLL_API(comm, han_module, allgatherv); + HAN_UNINSTALL_COLL_API(comm, han_module, allreduce); + HAN_UNINSTALL_COLL_API(comm, han_module, barrier); + HAN_UNINSTALL_COLL_API(comm, han_module, bcast); + HAN_UNINSTALL_COLL_API(comm, han_module, gather); + HAN_UNINSTALL_COLL_API(comm, han_module, gatherv); + HAN_UNINSTALL_COLL_API(comm, han_module, reduce); + HAN_UNINSTALL_COLL_API(comm, han_module, scatter); + HAN_UNINSTALL_COLL_API(comm, han_module, scatterv); han_module_clear(han_module); diff --git a/ompi/mca/coll/han/coll_han_reduce.c b/ompi/mca/coll/han/coll_han_reduce.c index 53d51dca3e3..4cb5c012d47 100644 --- a/ompi/mca/coll/han/coll_han_reduce.c +++ b/ompi/mca/coll/han/coll_han_reduce.c @@ -6,6 
+6,7 @@ * Copyright (c) 2022 IBM Corporation. All rights reserved * Copyright (c) 2024 Computer Architecture and VLSI Systems (CARV) * Laboratory, ICS Forth. All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -92,7 +93,7 @@ mca_coll_han_reduce_intra(const void *sbuf, OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output, "han cannot handle reduce with this communicator. Drop HAN support in this communicator and fall back on another component\n")); /* HAN cannot work with this communicator so fallback on all modules */ - HAN_LOAD_FALLBACK_COLLECTIVES(han_module, comm); + HAN_LOAD_FALLBACK_COLLECTIVES(comm, han_module); return han_module->previous_reduce(sbuf, rbuf, count, dtype, op, root, comm, han_module->previous_reduce_module); } @@ -106,7 +107,7 @@ mca_coll_han_reduce_intra(const void *sbuf, /* Put back the fallback collective support and call it once. All * future calls will then be automatically redirected. */ - HAN_LOAD_FALLBACK_COLLECTIVE(han_module, comm, reduce); + HAN_UNINSTALL_COLL_API(comm, han_module, reduce); return han_module->previous_reduce(sbuf, rbuf, count, dtype, op, root, comm, han_module->previous_reduce_module); } @@ -311,7 +312,7 @@ mca_coll_han_reduce_intra_simple(const void *sbuf, OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output, "han cannot handle reduce with this communicator. Drop HAN support in this communicator and fall back on another component\n")); /* HAN cannot work with this communicator so fallback on all collectives */ - HAN_LOAD_FALLBACK_COLLECTIVES(han_module, comm); + HAN_LOAD_FALLBACK_COLLECTIVES(comm, han_module); return han_module->previous_reduce(sbuf, rbuf, count, dtype, op, root, comm, han_module->previous_reduce_module); } @@ -325,7 +326,7 @@ mca_coll_han_reduce_intra_simple(const void *sbuf, /* Put back the fallback collective support and call it once. All * future calls will then be automatically redirected. */ - HAN_LOAD_FALLBACK_COLLECTIVE(han_module, comm, reduce); + HAN_UNINSTALL_COLL_API(comm, han_module, reduce); return han_module->previous_reduce(sbuf, rbuf, count, dtype, op, root, comm, han_module->previous_reduce_module); } diff --git a/ompi/mca/coll/han/coll_han_scatter.c b/ompi/mca/coll/han/coll_han_scatter.c index 87201597bb4..16193d54904 100644 --- a/ompi/mca/coll/han/coll_han_scatter.c +++ b/ompi/mca/coll/han/coll_han_scatter.c @@ -3,6 +3,7 @@ * of Tennessee Research Foundation. All rights * reserved. * Copyright (c) 2022 IBM Corporation. All rights reserved + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -86,7 +87,7 @@ mca_coll_han_scatter_intra(const void *sbuf, int scount, OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output, "han cannot handle scatter with this communicator. Fall back on another component\n")); /* HAN cannot work with this communicator so fallback on all collectives */ - HAN_LOAD_FALLBACK_COLLECTIVES(han_module, comm); + HAN_LOAD_FALLBACK_COLLECTIVES(comm, han_module); return han_module->previous_scatter(sbuf, scount, sdtype, rbuf, rcount, rdtype, root, comm, han_module->previous_scatter_module); } @@ -100,7 +101,7 @@ mca_coll_han_scatter_intra(const void *sbuf, int scount, /* Put back the fallback collective support and call it once. All * future calls will then be automatically redirected. 
*/ - HAN_LOAD_FALLBACK_COLLECTIVE(han_module, comm, scatter); + HAN_UNINSTALL_COLL_API(comm, han_module, scatter); return han_module->previous_scatter(sbuf, scount, sdtype, rbuf, rcount, rdtype, root, comm, han_module->previous_scatter_module); } @@ -284,7 +285,7 @@ mca_coll_han_scatter_intra_simple(const void *sbuf, int scount, "han cannot handle allgather within this communicator." " Fall back on another component\n")); /* HAN cannot work with this communicator so fallback on all collectives */ - HAN_LOAD_FALLBACK_COLLECTIVES(han_module, comm); + HAN_LOAD_FALLBACK_COLLECTIVES(comm, han_module); return han_module->previous_scatter(sbuf, scount, sdtype, rbuf, rcount, rdtype, root, comm, han_module->previous_scatter_module); } @@ -294,7 +295,7 @@ mca_coll_han_scatter_intra_simple(const void *sbuf, int scount, if (han_module->are_ppn_imbalanced){ OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output, "han cannot handle scatter with this communicator. It needs to fall back on another component\n")); - HAN_LOAD_FALLBACK_COLLECTIVES(han_module, comm); + HAN_UNINSTALL_COLL_API(comm, han_module, scatter); return han_module->previous_scatter(sbuf, scount, sdtype, rbuf, rcount, rdtype, root, comm, han_module->previous_scatter_module); } diff --git a/ompi/mca/coll/han/coll_han_scatterv.c b/ompi/mca/coll/han/coll_han_scatterv.c index 7cd7f276d3f..16b16754313 100644 --- a/ompi/mca/coll/han/coll_han_scatterv.c +++ b/ompi/mca/coll/han/coll_han_scatterv.c @@ -7,6 +7,7 @@ * Copyright (c) 2022 IBM Corporation. All rights reserved * Copyright (c) Amazon.com, Inc. or its affiliates. * All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -81,7 +82,7 @@ int mca_coll_han_scatterv_intra(const void *sbuf, const int *scounts, const int 30, mca_coll_han_component.han_output, "han cannot handle scatterv with this communicator. Fall back on another component\n")); /* HAN cannot work with this communicator so fallback on all collectives */ - HAN_LOAD_FALLBACK_COLLECTIVES(han_module, comm); + HAN_LOAD_FALLBACK_COLLECTIVES(comm, han_module); return han_module->previous_scatterv(sbuf, scounts, displs, sdtype, rbuf, rcount, rdtype, root, comm, han_module->previous_scatterv_module); } @@ -96,7 +97,7 @@ int mca_coll_han_scatterv_intra(const void *sbuf, const int *scounts, const int /* Put back the fallback collective support and call it once. All * future calls will then be automatically redirected. */ - HAN_LOAD_FALLBACK_COLLECTIVE(han_module, comm, scatterv); + HAN_UNINSTALL_COLL_API(comm, han_module, scatterv); return han_module->previous_scatterv(sbuf, scounts, displs, sdtype, rbuf, rcount, rdtype, root, comm, han_module->previous_scatterv_module); } @@ -104,7 +105,7 @@ int mca_coll_han_scatterv_intra(const void *sbuf, const int *scounts, const int OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output, "han cannot handle scatterv with this communicator (heterogeneous). Fall " "back on another component\n")); - HAN_LOAD_FALLBACK_COLLECTIVE(han_module, comm, scatterv); + HAN_UNINSTALL_COLL_API(comm, han_module, scatterv); return han_module->previous_scatterv(sbuf, scounts, displs, sdtype, rbuf, rcount, rdtype, root, comm, han_module->previous_scatterv_module); } diff --git a/ompi/mca/coll/han/coll_han_subcomms.c b/ompi/mca/coll/han/coll_han_subcomms.c index 47ef348975e..92bddb3ba51 100644 --- a/ompi/mca/coll/han/coll_han_subcomms.c +++ b/ompi/mca/coll/han/coll_han_subcomms.c @@ -3,7 +3,10 @@ * of Tennessee Research Foundation. 
All rights * reserved. * Copyright (c) 2020 Bull S.A.S. All rights reserved. + * Copyright (c) 2023 Computer Architecture and VLSI Systems (CARV) + * Laboratory, ICS Forth. All rights reserved. * + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -26,19 +29,21 @@ #include "coll_han.h" #include "coll_han_dynamic.h" -#define HAN_SUBCOM_SAVE_COLLECTIVE(FALLBACKS, COMM, HANM, COLL) \ - do { \ - (FALLBACKS).COLL.module_fn.COLL = (COMM)->c_coll->coll_ ## COLL; \ - (FALLBACKS).COLL.module = (COMM)->c_coll->coll_ ## COLL ## _module; \ - (COMM)->c_coll->coll_ ## COLL = (HANM)->fallback.COLL.module_fn.COLL; \ - (COMM)->c_coll->coll_ ## COLL ## _module = (HANM)->fallback.COLL.module; \ - } while(0) - -#define HAN_SUBCOM_LOAD_COLLECTIVE(FALLBACKS, COMM, HANM, COLL) \ - do { \ - (COMM)->c_coll->coll_ ## COLL = (FALLBACKS).COLL.module_fn.COLL; \ - (COMM)->c_coll->coll_ ## COLL ## _module = (FALLBACKS).COLL.module; \ - } while(0) +#define HAN_SUBCOM_SAVE_COLLECTIVE(FALLBACKS, COMM, HANM, COLL) \ + do \ + { \ + (FALLBACKS).COLL.COLL = (COMM)->c_coll->coll_##COLL; \ + (FALLBACKS).COLL.module = (COMM)->c_coll->coll_##COLL##_module; \ + (COMM)->c_coll->coll_##COLL = (HANM)->fallback.COLL.COLL; \ + (COMM)->c_coll->coll_##COLL##_module = (HANM)->fallback.COLL.module; \ + } while (0) + +#define HAN_SUBCOM_RESTORE_COLLECTIVE(FALLBACKS, COMM, HANM, COLL) \ + do \ + { \ + (COMM)->c_coll->coll_##COLL = (FALLBACKS).COLL.COLL; \ + (COMM)->c_coll->coll_##COLL##_module = (FALLBACKS).COLL.module; \ + } while (0) /* * Routine that creates the local hierarchical sub-communicators @@ -64,8 +69,8 @@ int mca_coll_han_comm_create_new(struct ompi_communicator_t *comm, /* * We cannot use han allreduce and allgather without sub-communicators, - * but we are in the creation of the data structures for the HAN, and - * temporarily need to save back the old collective. + * but we are in the creation of the data structures for HAN, and + * temporarily need to use the old collective. * * Allgather is used to compute vranks * Allreduce is used by ompi_comm_split_type in create_intranode_comm_new @@ -90,7 +95,8 @@ int mca_coll_han_comm_create_new(struct ompi_communicator_t *comm, * outside the MPI support (with PRRTE the info will be eventually available, * but we don't want to delay anything until then). We can achieve the same * goal by using a reduction over the maximum number of peers per node among - * all participants. + * all participants, but we need to call the fallback allreduce (or we will + * call HAN's allreduce again). 
*/ int local_procs = ompi_group_count_local_peers(comm->c_local_group); rc = comm->c_coll->coll_allreduce(MPI_IN_PLACE, &local_procs, 1, MPI_INT, @@ -100,18 +106,18 @@ int mca_coll_han_comm_create_new(struct ompi_communicator_t *comm, goto return_with_error; } if( local_procs == 1 ) { - /* restore saved collectives */ - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, allgatherv); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, allgather); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, allreduce); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, bcast); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, reduce); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, gather); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, gatherv); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, scatter); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, scatterv); han_module->enabled = false; /* entire module set to pass-through from now on */ - return OMPI_ERR_NOT_SUPPORTED; + /* restore saved collectives */ + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, allgatherv); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, allgather); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, allreduce); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, bcast); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, reduce); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, gather); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, gatherv); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, scatter); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, scatterv); + return OMPI_ERR_NOT_SUPPORTED; } OBJ_CONSTRUCT(&comm_info, opal_info_t); @@ -178,21 +184,22 @@ int mca_coll_han_comm_create_new(struct ompi_communicator_t *comm, */ han_module->cached_vranks = vranks; - /* Reset the saved collectives to point back to HAN */ - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, allgatherv); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, allgather); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, allreduce); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, bcast); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, reduce); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, gather); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, gatherv); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, scatter); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, scatterv); + /* Restore the saved collectives */ + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, allgatherv); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, allgather); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, allreduce); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, bcast); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, reduce); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, gather); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, gatherv); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, scatter); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, scatterv); OBJ_DESTRUCT(&comm_info); return OMPI_SUCCESS; return_with_error: + han_module->enabled = false; /* entire module set to pass-through from now on */ if( NULL != *low_comm ) { ompi_comm_free(low_comm); *low_comm = NULL; /* don't leave the MPI_COMM_NULL set by 
ompi_comm_free */ @@ -229,7 +236,7 @@ int mca_coll_han_comm_create(struct ompi_communicator_t *comm, /* * We cannot use han allreduce and allgather without sub-communicators, * but we are in the creation of the data structures for the HAN, and - * temporarily need to save back the old collective. + * temporarily need to use the old collective. * * Allgather is used to compute vranks * Allreduce is used by ompi_comm_split_type in create_intranode_comm_new @@ -262,15 +269,15 @@ int mca_coll_han_comm_create(struct ompi_communicator_t *comm, comm->c_coll->coll_allreduce_module); if( local_procs == 1 ) { /* restore saved collectives */ - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, allgatherv); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, allgather); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, allreduce); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, bcast); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, reduce); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, gather); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, gatherv); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, scatter); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, scatterv); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, allgatherv); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, allgather); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, allreduce); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, bcast); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, reduce); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, gather); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, gatherv); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, scatter); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, scatterv); han_module->enabled = false; /* entire module set to pass-through from now on */ return OMPI_ERR_NOT_SUPPORTED; } @@ -309,6 +316,10 @@ int mca_coll_han_comm_create(struct ompi_communicator_t *comm, &comm_info, &(low_comms[1])); assert(OMPI_COMM_IS_DISJOINT_SET(low_comms[1]) && !OMPI_COMM_IS_DISJOINT(low_comms[1])); + opal_info_set(&comm_info, "ompi_comm_coll_preference", "xhc,^han"); + ompi_comm_split_type(comm, MPI_COMM_TYPE_SHARED, 0, + &comm_info, &(low_comms[2])); + /* * Upgrade libnbc module priority to set up up_comms[0] with libnbc module * This sub-communicator contains one process per node: processes with the @@ -352,15 +363,15 @@ int mca_coll_han_comm_create(struct ompi_communicator_t *comm, han_module->cached_vranks = vranks; /* Reset the saved collectives to point back to HAN */ - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, allgatherv); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, allgather); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, allreduce); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, bcast); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, reduce); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, gather); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, gatherv); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, scatter); - HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, scatterv); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, allgatherv); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, allgather); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, allreduce); + 
HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, bcast); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, reduce); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, gather); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, gatherv); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, scatter); + HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, scatterv); OBJ_DESTRUCT(&comm_info); return OMPI_SUCCESS; diff --git a/ompi/mca/coll/hcoll/coll_hcoll_module.c b/ompi/mca/coll/hcoll/coll_hcoll_module.c index 0b8c2db4d48..5ca588a8154 100644 --- a/ompi/mca/coll/hcoll/coll_hcoll_module.c +++ b/ompi/mca/coll/hcoll/coll_hcoll_module.c @@ -7,6 +7,7 @@ * Copyright (c) 2018 Cisco Systems, Inc. All rights reserved * Copyright (c) 2022 Amazon.com, Inc. or its affiliates. * All Rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -86,8 +87,6 @@ static void mca_coll_hcoll_module_clear(mca_coll_hcoll_module_t *hcoll_module) hcoll_module->previous_igatherv_module = NULL; hcoll_module->previous_ialltoall_module = NULL; hcoll_module->previous_ialltoallv_module = NULL; - - } static void mca_coll_hcoll_module_construct(mca_coll_hcoll_module_t *hcoll_module) @@ -101,8 +100,6 @@ void mca_coll_hcoll_mem_release_cb(void *buf, size_t length, hcoll_mem_unmap(buf, length, cbdata, from_alloc); } -#define OBJ_RELEASE_IF_NOT_NULL( obj ) if( NULL != (obj) ) OBJ_RELEASE( obj ); - static void mca_coll_hcoll_module_destruct(mca_coll_hcoll_module_t *hcoll_module) { int context_destroyed; @@ -119,37 +116,7 @@ static void mca_coll_hcoll_module_destruct(mca_coll_hcoll_module_t *hcoll_module destroy hcoll context*/ if (hcoll_module->hcoll_context != NULL){ - OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_barrier_module); - OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_bcast_module); - OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_allreduce_module); - OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_reduce_scatter_block_module); - OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_reduce_scatter_module); - OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_allgather_module); - OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_allgatherv_module); - OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_gatherv_module); - OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_scatterv_module); - OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_alltoall_module); - OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_alltoallv_module); - OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_reduce_module); - - OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_ibarrier_module); - OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_ibcast_module); - OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_iallreduce_module); - OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_iallgather_module); - OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_iallgatherv_module); - OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_igatherv_module); - OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_ialltoall_module); - OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_ialltoallv_module); - OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_ireduce_module); - - /* - OBJ_RELEASE(hcoll_module->previous_allgatherv_module); - OBJ_RELEASE(hcoll_module->previous_gather_module); - OBJ_RELEASE(hcoll_module->previous_gatherv_module); - OBJ_RELEASE(hcoll_module->previous_alltoallw_module); - OBJ_RELEASE(hcoll_module->previous_reduce_scatter_module); - 
OBJ_RELEASE(hcoll_module->previous_reduce_module); - */ + #if !defined(HAVE_HCOLL_CONTEXT_FREE) context_destroyed = 0; hcoll_destroy_context(hcoll_module->hcoll_context, @@ -160,52 +127,105 @@ static void mca_coll_hcoll_module_destruct(mca_coll_hcoll_module_t *hcoll_module mca_coll_hcoll_module_clear(hcoll_module); } -#define HCOL_SAVE_PREV_COLL_API(__api) do {\ - hcoll_module->previous_ ## __api = comm->c_coll->coll_ ## __api;\ - hcoll_module->previous_ ## __api ## _module = comm->c_coll->coll_ ## __api ## _module;\ - if (!comm->c_coll->coll_ ## __api || !comm->c_coll->coll_ ## __api ## _module) {\ - return OMPI_ERROR;\ - }\ - OBJ_RETAIN(hcoll_module->previous_ ## __api ## _module);\ -} while(0) - +#define HCOL_INSTALL_COLL_API(__comm, __module, __api) \ + do \ + { \ + if (NULL != __module->super.coll_##__api) \ + { \ + if (comm->c_coll->coll_##__api && !comm->c_coll->coll_##__api##_module) \ + { \ + /* save the current selected collective */ \ + MCA_COLL_SAVE_API(__comm, __api, hcoll_module->previous_##__api, hcoll_module->previous_##__api##_module, "hcoll"); \ + /* install our own */ \ + MCA_COLL_INSTALL_API(__comm, __api, __module->super.coll_##__api, &__module->super, "hcoll"); \ + } \ + } \ + } while (0) + +#define HCOL_UNINSTALL_COLL_API(__comm, __module, __api) \ + do \ + { \ + if (&__module->super == comm->c_coll->coll_##__api##_module) \ + { \ + MCA_COLL_INSTALL_API(__comm, __api, __module->previous_##__api, __module->previous_##__api##_module, "hcoll"); \ + hcoll_module->previous_##__api = NULL; \ + hcoll_module->previous_##__api##_module = NULL; \ + } \ + } while (0) static int mca_coll_hcoll_save_coll_handlers(mca_coll_hcoll_module_t *hcoll_module) { ompi_communicator_t *comm; comm = hcoll_module->comm; - HCOL_SAVE_PREV_COLL_API(barrier); - HCOL_SAVE_PREV_COLL_API(bcast); - HCOL_SAVE_PREV_COLL_API(allreduce); - HCOL_SAVE_PREV_COLL_API(reduce_scatter_block); - HCOL_SAVE_PREV_COLL_API(reduce_scatter); - HCOL_SAVE_PREV_COLL_API(reduce); - HCOL_SAVE_PREV_COLL_API(allgather); - HCOL_SAVE_PREV_COLL_API(allgatherv); - HCOL_SAVE_PREV_COLL_API(gatherv); - HCOL_SAVE_PREV_COLL_API(scatterv); - HCOL_SAVE_PREV_COLL_API(alltoall); - HCOL_SAVE_PREV_COLL_API(alltoallv); - - HCOL_SAVE_PREV_COLL_API(ibarrier); - HCOL_SAVE_PREV_COLL_API(ibcast); - HCOL_SAVE_PREV_COLL_API(iallreduce); - HCOL_SAVE_PREV_COLL_API(ireduce); - HCOL_SAVE_PREV_COLL_API(iallgather); - HCOL_SAVE_PREV_COLL_API(iallgatherv); - HCOL_SAVE_PREV_COLL_API(igatherv); - HCOL_SAVE_PREV_COLL_API(ialltoall); - HCOL_SAVE_PREV_COLL_API(ialltoallv); + hcoll_module->super.coll_barrier = hcoll_collectives.coll_barrier ? mca_coll_hcoll_barrier : NULL; + hcoll_module->super.coll_bcast = hcoll_collectives.coll_bcast ? mca_coll_hcoll_bcast : NULL; + hcoll_module->super.coll_allgather = hcoll_collectives.coll_allgather ? mca_coll_hcoll_allgather : NULL; + hcoll_module->super.coll_allgatherv = hcoll_collectives.coll_allgatherv ? mca_coll_hcoll_allgatherv : NULL; + hcoll_module->super.coll_allreduce = hcoll_collectives.coll_allreduce ? mca_coll_hcoll_allreduce : NULL; + hcoll_module->super.coll_alltoall = hcoll_collectives.coll_alltoall ? mca_coll_hcoll_alltoall : NULL; + hcoll_module->super.coll_alltoallv = hcoll_collectives.coll_alltoallv ? mca_coll_hcoll_alltoallv : NULL; + hcoll_module->super.coll_gatherv = hcoll_collectives.coll_gatherv ? mca_coll_hcoll_gatherv : NULL; + hcoll_module->super.coll_scatterv = hcoll_collectives.coll_scatterv ? mca_coll_hcoll_scatterv : NULL; + hcoll_module->super.coll_reduce = hcoll_collectives.coll_reduce ? 
mca_coll_hcoll_reduce : NULL; + hcoll_module->super.coll_ibarrier = hcoll_collectives.coll_ibarrier ? mca_coll_hcoll_ibarrier : NULL; + hcoll_module->super.coll_ibcast = hcoll_collectives.coll_ibcast ? mca_coll_hcoll_ibcast : NULL; + hcoll_module->super.coll_iallgather = hcoll_collectives.coll_iallgather ? mca_coll_hcoll_iallgather : NULL; +#if HCOLL_API >= HCOLL_VERSION(3, 5) + hcoll_module->super.coll_iallgatherv = hcoll_collectives.coll_iallgatherv ? mca_coll_hcoll_iallgatherv : NULL; +#else + hcoll_module->super.coll_iallgatherv = NULL; +#endif + hcoll_module->super.coll_iallreduce = hcoll_collectives.coll_iallreduce ? mca_coll_hcoll_iallreduce : NULL; +#if HCOLL_API >= HCOLL_VERSION(3, 5) + hcoll_module->super.coll_ireduce = hcoll_collectives.coll_ireduce ? mca_coll_hcoll_ireduce : NULL; +#else + hcoll_module->super.coll_ireduce = NULL; +#endif + hcoll_module->super.coll_gather = /*hcoll_collectives.coll_gather ? mca_coll_hcoll_gather :*/ NULL; + hcoll_module->super.coll_igatherv = hcoll_collectives.coll_igatherv ? mca_coll_hcoll_igatherv : NULL; + hcoll_module->super.coll_ialltoall = /*hcoll_collectives.coll_ialltoall ? mca_coll_hcoll_ialltoall : */ NULL; +#if HCOLL_API >= HCOLL_VERSION(3, 7) + hcoll_module->super.coll_ialltoallv = hcoll_collectives.coll_ialltoallv ? mca_coll_hcoll_ialltoallv : NULL; +#else + hcoll_module->super.coll_ialltoallv = NULL; +#endif +#if HCOLL_API > HCOLL_VERSION(4, 5) + hcoll_module->super.coll_reduce_scatter_block = hcoll_collectives.coll_reduce_scatter_block ? mca_coll_hcoll_reduce_scatter_block : NULL; + hcoll_module->super.coll_reduce_scatter = hcoll_collectives.coll_reduce_scatter ? mca_coll_hcoll_reduce_scatter : NULL; +#endif + + HCOL_INSTALL_COLL_API(comm, hcoll_module, barrier); + HCOL_INSTALL_COLL_API(comm, hcoll_module, bcast); + HCOL_INSTALL_COLL_API(comm, hcoll_module, allreduce); + HCOL_INSTALL_COLL_API(comm, hcoll_module, reduce_scatter_block); + HCOL_INSTALL_COLL_API(comm, hcoll_module, reduce_scatter); + HCOL_INSTALL_COLL_API(comm, hcoll_module, reduce); + HCOL_INSTALL_COLL_API(comm, hcoll_module, allgather); + HCOL_INSTALL_COLL_API(comm, hcoll_module, allgatherv); + HCOL_INSTALL_COLL_API(comm, hcoll_module, gatherv); + HCOL_INSTALL_COLL_API(comm, hcoll_module, scatterv); + HCOL_INSTALL_COLL_API(comm, hcoll_module, alltoall); + HCOL_INSTALL_COLL_API(comm, hcoll_module, alltoallv); + + HCOL_INSTALL_COLL_API(comm, hcoll_module, ibarrier); + HCOL_INSTALL_COLL_API(comm, hcoll_module, ibcast); + HCOL_INSTALL_COLL_API(comm, hcoll_module, iallreduce); + HCOL_INSTALL_COLL_API(comm, hcoll_module, ireduce); + HCOL_INSTALL_COLL_API(comm, hcoll_module, iallgather); + HCOL_INSTALL_COLL_API(comm, hcoll_module, iallgatherv); + HCOL_INSTALL_COLL_API(comm, hcoll_module, igatherv); + HCOL_INSTALL_COLL_API(comm, hcoll_module, ialltoall); + HCOL_INSTALL_COLL_API(comm, hcoll_module, ialltoallv); /* These collectives are not yet part of hcoll, so don't retain them on hcoll module - HCOL_SAVE_PREV_COLL_API(reduce_scatter); - HCOL_SAVE_PREV_COLL_API(gather); - HCOL_SAVE_PREV_COLL_API(reduce); - HCOL_SAVE_PREV_COLL_API(allgatherv); - HCOL_SAVE_PREV_COLL_API(alltoallw); + HCOL_INSTALL_COLL_API(comm, hcoll_module, reduce_scatter); + HCOL_INSTALL_COLL_API(comm, hcoll_module, gather); + HCOL_INSTALL_COLL_API(comm, hcoll_module, reduce); + HCOL_INSTALL_COLL_API(comm, hcoll_module, allgatherv); + HCOL_INSTALL_COLL_API(comm, hcoll_module, alltoallw); */ return OMPI_SUCCESS; } @@ -251,6 +271,45 @@ static int mca_coll_hcoll_module_enable(mca_coll_base_module_t *module, 
return OMPI_SUCCESS; } +static int mca_coll_hcoll_module_disable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm) +{ + mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t *)module; + + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, barrier); + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, bcast); + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, allreduce); + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, reduce_scatter_block); + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, reduce_scatter); + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, reduce); + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, allgather); + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, allgatherv); + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, gatherv); + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, scatterv); + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, alltoall); + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, alltoallv); + + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, ibarrier); + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, ibcast); + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, iallreduce); + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, ireduce); + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, iallgather); + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, iallgatherv); + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, igatherv); + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, ialltoall); + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, ialltoallv); + + /* + These collectives are not yet part of hcoll, so + don't retain them on hcoll module + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, reduce_scatter); + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, gather); + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, reduce); + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, allgatherv); + HCOL_UNINSTALL_COLL_API(comm, hcoll_module, alltoallw); + */ + return OMPI_SUCCESS; +} OBJ_CLASS_INSTANCE(mca_coll_hcoll_dtype_t, opal_free_list_item_t, @@ -395,44 +454,8 @@ mca_coll_hcoll_comm_query(struct ompi_communicator_t *comm, int *priority) } hcoll_module->super.coll_module_enable = mca_coll_hcoll_module_enable; - hcoll_module->super.coll_barrier = hcoll_collectives.coll_barrier ? mca_coll_hcoll_barrier : NULL; - hcoll_module->super.coll_bcast = hcoll_collectives.coll_bcast ? mca_coll_hcoll_bcast : NULL; - hcoll_module->super.coll_allgather = hcoll_collectives.coll_allgather ? mca_coll_hcoll_allgather : NULL; - hcoll_module->super.coll_allgatherv = hcoll_collectives.coll_allgatherv ? mca_coll_hcoll_allgatherv : NULL; - hcoll_module->super.coll_allreduce = hcoll_collectives.coll_allreduce ? mca_coll_hcoll_allreduce : NULL; - hcoll_module->super.coll_alltoall = hcoll_collectives.coll_alltoall ? mca_coll_hcoll_alltoall : NULL; - hcoll_module->super.coll_alltoallv = hcoll_collectives.coll_alltoallv ? mca_coll_hcoll_alltoallv : NULL; - hcoll_module->super.coll_gatherv = hcoll_collectives.coll_gatherv ? mca_coll_hcoll_gatherv : NULL; - hcoll_module->super.coll_scatterv = hcoll_collectives.coll_scatterv ? mca_coll_hcoll_scatterv : NULL; - hcoll_module->super.coll_reduce = hcoll_collectives.coll_reduce ? mca_coll_hcoll_reduce : NULL; - hcoll_module->super.coll_ibarrier = hcoll_collectives.coll_ibarrier ? mca_coll_hcoll_ibarrier : NULL; - hcoll_module->super.coll_ibcast = hcoll_collectives.coll_ibcast ? mca_coll_hcoll_ibcast : NULL; - hcoll_module->super.coll_iallgather = hcoll_collectives.coll_iallgather ? mca_coll_hcoll_iallgather : NULL; -#if HCOLL_API >= HCOLL_VERSION(3,5) - hcoll_module->super.coll_iallgatherv = hcoll_collectives.coll_iallgatherv ? 
mca_coll_hcoll_iallgatherv : NULL; -#else - hcoll_module->super.coll_iallgatherv = NULL; -#endif - hcoll_module->super.coll_iallreduce = hcoll_collectives.coll_iallreduce ? mca_coll_hcoll_iallreduce : NULL; -#if HCOLL_API >= HCOLL_VERSION(3,5) - hcoll_module->super.coll_ireduce = hcoll_collectives.coll_ireduce ? mca_coll_hcoll_ireduce : NULL; -#else - hcoll_module->super.coll_ireduce = NULL; -#endif - hcoll_module->super.coll_gather = /*hcoll_collectives.coll_gather ? mca_coll_hcoll_gather :*/ NULL; - hcoll_module->super.coll_igatherv = hcoll_collectives.coll_igatherv ? mca_coll_hcoll_igatherv : NULL; - hcoll_module->super.coll_ialltoall = /*hcoll_collectives.coll_ialltoall ? mca_coll_hcoll_ialltoall : */ NULL; -#if HCOLL_API >= HCOLL_VERSION(3,7) - hcoll_module->super.coll_ialltoallv = hcoll_collectives.coll_ialltoallv ? mca_coll_hcoll_ialltoallv : NULL; -#else - hcoll_module->super.coll_ialltoallv = NULL; -#endif -#if HCOLL_API > HCOLL_VERSION(4,5) - hcoll_module->super.coll_reduce_scatter_block = hcoll_collectives.coll_reduce_scatter_block ? - mca_coll_hcoll_reduce_scatter_block : NULL; - hcoll_module->super.coll_reduce_scatter = hcoll_collectives.coll_reduce_scatter ? - mca_coll_hcoll_reduce_scatter : NULL; -#endif + hcoll_module->super.coll_module_disable = mca_coll_hcoll_module_disable; + *priority = cm->hcoll_priority; module = &hcoll_module->super; diff --git a/ompi/mca/coll/inter/Makefile.am b/ompi/mca/coll/inter/Makefile.am index d9c691cf458..ea9188cffd4 100644 --- a/ompi/mca/coll/inter/Makefile.am +++ b/ompi/mca/coll/inter/Makefile.am @@ -11,6 +11,7 @@ # All rights reserved. # Copyright (c) 2010 Cisco Systems, Inc. All rights reserved. # Copyright (c) 2017 IBM Corporation. All rights reserved. +# Copyright (c) 2024 NVIDIA Corporation. All rights reserved. # $COPYRIGHT$ # # Additional copyrights may follow @@ -43,7 +44,7 @@ libmca_coll_inter_la_LDFLAGS = -module -avoid-version sources = \ coll_inter.h \ - coll_inter.c \ + coll_inter_module.c \ coll_inter_allreduce.c \ coll_inter_allgather.c \ coll_inter_allgatherv.c \ diff --git a/ompi/mca/coll/inter/coll_inter.h b/ompi/mca/coll/inter/coll_inter.h index 09bb78c5c79..17dc055451c 100644 --- a/ompi/mca/coll/inter/coll_inter.h +++ b/ompi/mca/coll/inter/coll_inter.h @@ -13,6 +13,7 @@ * Copyright (c) 2008 Cisco Systems, Inc. All rights reserved. * Copyright (c) 2015 Research Organization for Information Science * and Technology (RIST). All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -50,9 +51,6 @@ int mca_coll_inter_init_query(bool allow_inter_user_threads, mca_coll_base_module_t * mca_coll_inter_comm_query(struct ompi_communicator_t *comm, int *priority); -int mca_coll_inter_module_enable(mca_coll_base_module_t *module, - struct ompi_communicator_t *comm); - int mca_coll_inter_allgather_inter(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void *rbuf, int rcount, diff --git a/ompi/mca/coll/inter/coll_inter_allgather.c b/ompi/mca/coll/inter/coll_inter_allgather.c index fe867cda06a..59cd19ff75f 100644 --- a/ompi/mca/coll/inter/coll_inter_allgather.c +++ b/ompi/mca/coll/inter/coll_inter_allgather.c @@ -13,6 +13,7 @@ * Copyright (c) 2015-2017 Research Organization for Information Science * and Technology (RIST). All rights reserved. * Copyright (c) 2022 IBM Corporation. All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. 
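The hcoll hunks above, like the rest of this patch, delegate the actual pointer swapping to the new MCA_COLL_SAVE_API / MCA_COLL_INSTALL_API helpers from the coll base framework. Their definitions are not part of this diff, so the following is only a rough, simplified sketch of the save/install step they are assumed to perform (stand-in types, no verbose logging, no reference counting):

typedef int (*sketch_bcast_fn_t)(void);          /* stands in for the real bcast signature */
struct sketch_module;                            /* stands in for mca_coll_base_module_t   */

struct sketch_coll_table {                       /* stands in for comm->c_coll             */
    sketch_bcast_fn_t     coll_bcast;
    struct sketch_module *coll_bcast_module;
};

/* roughly what MCA_COLL_SAVE_API(comm, bcast, prev_fn, prev_mod, "hcoll") is assumed to do */
static void sketch_save_bcast(struct sketch_coll_table *c_coll,
                              sketch_bcast_fn_t *prev_fn, struct sketch_module **prev_mod)
{
    *prev_fn  = c_coll->coll_bcast;          /* remember the currently selected pair ... */
    *prev_mod = c_coll->coll_bcast_module;   /* ... so module_disable can restore it     */
}

/* roughly what MCA_COLL_INSTALL_API(comm, bcast, fn, mod, "hcoll") is assumed to do */
static void sketch_install_bcast(struct sketch_coll_table *c_coll,
                                 sketch_bcast_fn_t fn, struct sketch_module *mod)
{
    c_coll->coll_bcast        = fn;          /* point the communicator at the new module */
    c_coll->coll_bcast_module = mod;
}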
* $COPYRIGHT$ * * Additional copyrights may follow @@ -61,23 +62,23 @@ mca_coll_inter_allgather_inter(const void *sbuf, int scount, /* Perform the gather locally at the root */ if ( scount > 0 ) { span = opal_datatype_span(&sdtype->super, (int64_t)scount*(int64_t)size, &gap); - ptmp_free = (char*)malloc(span); - if (NULL == ptmp_free) { - return OMPI_ERR_OUT_OF_RESOURCE; - } + ptmp_free = (char*)malloc(span); + if (NULL == ptmp_free) { + return OMPI_ERR_OUT_OF_RESOURCE; + } ptmp = ptmp_free - gap; - err = comm->c_local_comm->c_coll->coll_gather(sbuf, scount, sdtype, + err = comm->c_local_comm->c_coll->coll_gather(sbuf, scount, sdtype, ptmp, scount, sdtype, 0, comm->c_local_comm, comm->c_local_comm->c_coll->coll_gather_module); - if (OMPI_SUCCESS != err) { - goto exit; - } + if (OMPI_SUCCESS != err) { + goto exit; + } } if (rank == root) { - /* Do a send-recv between the two root procs. to avoid deadlock */ + /* Do a send-recv between the two root procs. to avoid deadlock */ err = ompi_coll_base_sendrecv_actual(ptmp, scount*(size_t)size, sdtype, 0, MCA_COLL_BASE_TAG_ALLGATHER, rbuf, rcount*(size_t)rsize, rdtype, 0, diff --git a/ompi/mca/coll/inter/coll_inter_component.c b/ompi/mca/coll/inter/coll_inter_component.c index 39ff40c9cfb..76f7340b1be 100644 --- a/ompi/mca/coll/inter/coll_inter_component.c +++ b/ompi/mca/coll/inter/coll_inter_component.c @@ -14,6 +14,7 @@ * Copyright (c) 2008 Cisco Systems, Inc. All rights reserved. * Copyright (c) 2015 Los Alamos National Security, LLC. All rights * reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -113,14 +114,7 @@ mca_coll_inter_module_construct(mca_coll_inter_module_t *module) module->inter_comm = NULL; } -static void -mca_coll_inter_module_destruct(mca_coll_inter_module_t *module) -{ - -} - - OBJ_CLASS_INSTANCE(mca_coll_inter_module_t, mca_coll_base_module_t, mca_coll_inter_module_construct, - mca_coll_inter_module_destruct); + NULL); diff --git a/ompi/mca/coll/inter/coll_inter.c b/ompi/mca/coll/inter/coll_inter_module.c similarity index 63% rename from ompi/mca/coll/inter/coll_inter.c rename to ompi/mca/coll/inter/coll_inter_module.c index d75994e396a..2dd35c10fa2 100644 --- a/ompi/mca/coll/inter/coll_inter.c +++ b/ompi/mca/coll/inter/coll_inter_module.c @@ -13,6 +13,7 @@ * Copyright (c) 2008 Sun Microsystems, Inc. All rights reserved. * Copyright (c) 2013 Cisco Systems, Inc. All rights reserved. * Copyright (c) 2017 IBM Corporation. All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -37,7 +38,6 @@ #if 0 -static void mca_coll_inter_dump_struct ( struct mca_coll_base_comm_t *c); static const mca_coll_base_module_1_0_0_t inter = { @@ -68,6 +68,12 @@ static const mca_coll_base_module_1_0_0_t inter = { }; #endif +static int +mca_coll_inter_module_enable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm); +static int +mca_coll_inter_module_disable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm); /* * Initial query function that is invoked during MPI_INIT, allowing @@ -101,22 +107,23 @@ mca_coll_inter_comm_query(struct ompi_communicator_t *comm, int *priority) * than or equal to 0, then the module is unavailable. 
*/ *priority = mca_coll_inter_priority_param; if (0 >= mca_coll_inter_priority_param) { - return NULL; + return NULL; } size = ompi_comm_size(comm); rsize = ompi_comm_remote_size(comm); if ( size < mca_coll_inter_crossover && rsize < mca_coll_inter_crossover) { - return NULL; + return NULL; } inter_module = OBJ_NEW(mca_coll_inter_module_t); if (NULL == inter_module) { - return NULL; + return NULL; } inter_module->super.coll_module_enable = mca_coll_inter_module_enable; + inter_module->super.coll_module_disable = mca_coll_inter_module_disable; inter_module->super.coll_allgather = mca_coll_inter_allgather_inter; inter_module->super.coll_allgatherv = mca_coll_inter_allgatherv_inter; @@ -139,11 +146,24 @@ mca_coll_inter_comm_query(struct ompi_communicator_t *comm, int *priority) return &(inter_module->super); } - +#define INTER_INSTALL_COLL_API(__comm, __module, __api) \ + do \ + { \ + MCA_COLL_INSTALL_API(__comm, __api, __module->super.coll_##__api, &__module->super, "inter"); \ + } while (0) + +#define INTER_UNINSTALL_COLL_API(__comm, __module, __api) \ + do \ + { \ + if (__comm->c_coll->coll_##__api##_module == &__module->super) \ + { \ + MCA_COLL_INSTALL_API(__comm, __api, NULL, NULL, "inter"); \ + } \ + } while (0) /* * Init module on the communicator */ -int +static int mca_coll_inter_module_enable(mca_coll_base_module_t *module, struct ompi_communicator_t *comm) { @@ -151,27 +171,35 @@ mca_coll_inter_module_enable(mca_coll_base_module_t *module, inter_module->inter_comm = comm; -#if 0 - if ( mca_coll_inter_verbose_param ) { - mca_coll_inter_dump_struct (data); - } -#endif - + INTER_INSTALL_COLL_API(comm, inter_module, allgather); + INTER_INSTALL_COLL_API(comm, inter_module, allgatherv); + INTER_INSTALL_COLL_API(comm, inter_module, allreduce); + INTER_INSTALL_COLL_API(comm, inter_module, bcast); + INTER_INSTALL_COLL_API(comm, inter_module, gather); + INTER_INSTALL_COLL_API(comm, inter_module, gatherv); + INTER_INSTALL_COLL_API(comm, inter_module, reduce); + INTER_INSTALL_COLL_API(comm, inter_module, scatter); + INTER_INSTALL_COLL_API(comm, inter_module, scatterv); return OMPI_SUCCESS; } - -#if 0 -static void mca_coll_inter_dump_struct ( struct mca_coll_base_comm_t *c) +static int +mca_coll_inter_module_disable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm) { - int rank; - - rank = ompi_comm_rank ( c->inter_comm ); + mca_coll_inter_module_t *inter_module = (mca_coll_inter_module_t*) module; - printf("%d: Dump of inter-struct for comm %s cid %u\n", - rank, c->inter_comm->c_name, c->inter_comm->c_contextid); + INTER_UNINSTALL_COLL_API(comm, inter_module, allgather); + INTER_UNINSTALL_COLL_API(comm, inter_module, allgatherv); + INTER_UNINSTALL_COLL_API(comm, inter_module, allreduce); + INTER_UNINSTALL_COLL_API(comm, inter_module, bcast); + INTER_UNINSTALL_COLL_API(comm, inter_module, gather); + INTER_UNINSTALL_COLL_API(comm, inter_module, gatherv); + INTER_UNINSTALL_COLL_API(comm, inter_module, reduce); + INTER_UNINSTALL_COLL_API(comm, inter_module, scatter); + INTER_UNINSTALL_COLL_API(comm, inter_module, scatterv); + inter_module->inter_comm = NULL; - return; + return OMPI_SUCCESS; } -#endif diff --git a/ompi/mca/coll/libnbc/coll_libnbc_component.c b/ompi/mca/coll/libnbc/coll_libnbc_component.c index 3b9662ea682..803b3ae46b1 100644 --- a/ompi/mca/coll/libnbc/coll_libnbc_component.c +++ b/ompi/mca/coll/libnbc/coll_libnbc_component.c @@ -19,6 +19,7 @@ * Copyright (c) 2017 Ian Bradley Morgan and Anthony Skjellum. All * rights reserved. * Copyright (c) 2018 FUJITSU LIMITED. 
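INTER_UNINSTALL_COLL_API above, like the matching *_UNINSTALL_* macros elsewhere in this patch, only clears a slot the module still owns, because a higher-priority component enabled afterwards may already have replaced the entry. A small, self-contained illustration of why that ownership check matters (all names here are invented for the example):

#include <stdio.h>

struct module { const char *name; };

struct coll_entry {                 /* one slot of a comm->c_coll-like table */
    void          *fn;
    struct module *owner;
};

/* mirrors the "only uninstall what we installed" guard used above */
static void uninstall(struct coll_entry *slot, struct module *self)
{
    if (slot->owner == self) {      /* we are still the selected module      */
        slot->fn    = NULL;
        slot->owner = NULL;
    }                               /* otherwise leave the newer owner alone */
}

int main(void)
{
    struct module inter = { "inter" }, han = { "han" };
    struct coll_entry allgather = { NULL, &inter };  /* inter installed itself first        */

    allgather.owner = &han;           /* a higher-priority module took over later           */
    uninstall(&allgather, &inter);    /* inter's disable must not clobber han's entry       */
    printf("owner after inter disable: %s\n",
           allgather.owner ? allgather.owner->name : "(none)");
    return 0;
}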
All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -30,6 +31,7 @@ #include "coll_libnbc.h" #include "nbc_internal.h" +#include "ompi/mca/coll/base/base.h" #include "mpi.h" #include "ompi/mca/coll/coll.h" @@ -106,6 +108,7 @@ static int libnbc_register(void); static int libnbc_init_query(bool, bool); static mca_coll_base_module_t *libnbc_comm_query(struct ompi_communicator_t *, int *); static int libnbc_module_enable(mca_coll_base_module_t *, struct ompi_communicator_t *); +static int libnbc_module_disable(mca_coll_base_module_t *, struct ompi_communicator_t *); /* * Instantiate the public struct with all of our public information @@ -313,6 +316,7 @@ libnbc_comm_query(struct ompi_communicator_t *comm, *priority = libnbc_priority; module->super.coll_module_enable = libnbc_module_enable; + module->super.coll_module_disable = libnbc_module_disable; if (OMPI_COMM_IS_INTER(comm)) { module->super.coll_iallgather = ompi_coll_libnbc_iallgather_inter; module->super.coll_iallgatherv = ompi_coll_libnbc_iallgatherv_inter; @@ -407,7 +411,23 @@ libnbc_comm_query(struct ompi_communicator_t *comm, return &(module->super); } - +#define NBC_INSTALL_COLL_API(__comm, __module, __api) \ + do \ + { \ + if (__module->super.coll_##__api) \ + { \ + MCA_COLL_INSTALL_API(__comm, __api, __module->super.coll_##__api, &__module->super, "libnbc"); \ + } \ + } while (0) + +#define NBC_UNINSTALL_COLL_API(__comm, __module, __api) \ + do \ + { \ + if (__comm->c_coll->coll_##__api##_module == &__module->super) \ + { \ + MCA_COLL_INSTALL_API(__comm, __api, NULL, NULL, "libnbc"); \ + } \ + } while (0) /* * Init module on the communicator */ @@ -415,10 +435,120 @@ static int libnbc_module_enable(mca_coll_base_module_t *module, struct ompi_communicator_t *comm) { - /* All done */ + ompi_coll_libnbc_module_t *nbc_module = (ompi_coll_libnbc_module_t*)module; + + NBC_INSTALL_COLL_API(comm, nbc_module, iallgather); + NBC_INSTALL_COLL_API(comm, nbc_module, iallgatherv); + NBC_INSTALL_COLL_API(comm, nbc_module, iallreduce); + NBC_INSTALL_COLL_API(comm, nbc_module, ialltoall); + NBC_INSTALL_COLL_API(comm, nbc_module, ialltoallv); + NBC_INSTALL_COLL_API(comm, nbc_module, ialltoallw); + NBC_INSTALL_COLL_API(comm, nbc_module, ibarrier); + NBC_INSTALL_COLL_API(comm, nbc_module, ibcast); + NBC_INSTALL_COLL_API(comm, nbc_module, igather); + NBC_INSTALL_COLL_API(comm, nbc_module, igatherv); + NBC_INSTALL_COLL_API(comm, nbc_module, ireduce); + NBC_INSTALL_COLL_API(comm, nbc_module, ireduce_scatter); + NBC_INSTALL_COLL_API(comm, nbc_module, ireduce_scatter_block); + NBC_INSTALL_COLL_API(comm, nbc_module, iscatter); + NBC_INSTALL_COLL_API(comm, nbc_module, iscatterv); + + NBC_INSTALL_COLL_API(comm, nbc_module, allgather_init); + NBC_INSTALL_COLL_API(comm, nbc_module, allgatherv_init); + NBC_INSTALL_COLL_API(comm, nbc_module, allreduce_init); + NBC_INSTALL_COLL_API(comm, nbc_module, alltoall_init); + NBC_INSTALL_COLL_API(comm, nbc_module, alltoallv_init); + NBC_INSTALL_COLL_API(comm, nbc_module, alltoallw_init); + NBC_INSTALL_COLL_API(comm, nbc_module, barrier_init); + NBC_INSTALL_COLL_API(comm, nbc_module, bcast_init); + NBC_INSTALL_COLL_API(comm, nbc_module, gather_init); + NBC_INSTALL_COLL_API(comm, nbc_module, gatherv_init); + NBC_INSTALL_COLL_API(comm, nbc_module, reduce_init); + NBC_INSTALL_COLL_API(comm, nbc_module, reduce_scatter_init); + NBC_INSTALL_COLL_API(comm, nbc_module, reduce_scatter_block_init); + NBC_INSTALL_COLL_API(comm, nbc_module, 
scatter_init); + NBC_INSTALL_COLL_API(comm, nbc_module, scatterv_init); + + if (!OMPI_COMM_IS_INTER(comm)) { + NBC_INSTALL_COLL_API(comm, nbc_module, exscan_init); + NBC_INSTALL_COLL_API(comm, nbc_module, iexscan); + NBC_INSTALL_COLL_API(comm, nbc_module, scan_init); + NBC_INSTALL_COLL_API(comm, nbc_module, iscan); + + NBC_INSTALL_COLL_API(comm, nbc_module, ineighbor_allgather); + NBC_INSTALL_COLL_API(comm, nbc_module, ineighbor_allgatherv); + NBC_INSTALL_COLL_API(comm, nbc_module, ineighbor_alltoall); + NBC_INSTALL_COLL_API(comm, nbc_module, ineighbor_alltoallv); + NBC_INSTALL_COLL_API(comm, nbc_module, ineighbor_alltoallw); + + NBC_INSTALL_COLL_API(comm, nbc_module, neighbor_allgather_init); + NBC_INSTALL_COLL_API(comm, nbc_module, neighbor_allgatherv_init); + NBC_INSTALL_COLL_API(comm, nbc_module, neighbor_alltoall_init); + NBC_INSTALL_COLL_API(comm, nbc_module, neighbor_alltoallv_init); + NBC_INSTALL_COLL_API(comm, nbc_module, neighbor_alltoallw_init); + } /* All done */ + return OMPI_SUCCESS; } +static int +libnbc_module_disable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm) +{ + ompi_coll_libnbc_module_t *nbc_module = (ompi_coll_libnbc_module_t*)module; + + NBC_UNINSTALL_COLL_API(comm, nbc_module, iallgather); + NBC_UNINSTALL_COLL_API(comm, nbc_module, iallgatherv); + NBC_UNINSTALL_COLL_API(comm, nbc_module, iallreduce); + NBC_UNINSTALL_COLL_API(comm, nbc_module, ialltoall); + NBC_UNINSTALL_COLL_API(comm, nbc_module, ialltoallv); + NBC_UNINSTALL_COLL_API(comm, nbc_module, ialltoallw); + NBC_UNINSTALL_COLL_API(comm, nbc_module, ibarrier); + NBC_UNINSTALL_COLL_API(comm, nbc_module, ibcast); + NBC_UNINSTALL_COLL_API(comm, nbc_module, igather); + NBC_UNINSTALL_COLL_API(comm, nbc_module, igatherv); + NBC_UNINSTALL_COLL_API(comm, nbc_module, ireduce); + NBC_UNINSTALL_COLL_API(comm, nbc_module, ireduce_scatter); + NBC_UNINSTALL_COLL_API(comm, nbc_module, ireduce_scatter_block); + NBC_UNINSTALL_COLL_API(comm, nbc_module, iscatter); + NBC_UNINSTALL_COLL_API(comm, nbc_module, iscatterv); + + NBC_UNINSTALL_COLL_API(comm, nbc_module, allgather_init); + NBC_UNINSTALL_COLL_API(comm, nbc_module, allgatherv_init); + NBC_UNINSTALL_COLL_API(comm, nbc_module, allreduce_init); + NBC_UNINSTALL_COLL_API(comm, nbc_module, alltoall_init); + NBC_UNINSTALL_COLL_API(comm, nbc_module, alltoallv_init); + NBC_UNINSTALL_COLL_API(comm, nbc_module, alltoallw_init); + NBC_UNINSTALL_COLL_API(comm, nbc_module, barrier_init); + NBC_UNINSTALL_COLL_API(comm, nbc_module, bcast_init); + NBC_UNINSTALL_COLL_API(comm, nbc_module, gather_init); + NBC_UNINSTALL_COLL_API(comm, nbc_module, gatherv_init); + NBC_UNINSTALL_COLL_API(comm, nbc_module, reduce_init); + NBC_UNINSTALL_COLL_API(comm, nbc_module, reduce_scatter_init); + NBC_UNINSTALL_COLL_API(comm, nbc_module, reduce_scatter_block_init); + NBC_UNINSTALL_COLL_API(comm, nbc_module, scatter_init); + NBC_UNINSTALL_COLL_API(comm, nbc_module, scatterv_init); + + if (!OMPI_COMM_IS_INTER(comm)) { + NBC_UNINSTALL_COLL_API(comm, nbc_module, exscan_init); + NBC_UNINSTALL_COLL_API(comm, nbc_module, iexscan); + NBC_UNINSTALL_COLL_API(comm, nbc_module, scan_init); + NBC_UNINSTALL_COLL_API(comm, nbc_module, iscan); + + NBC_UNINSTALL_COLL_API(comm, nbc_module, ineighbor_allgather); + NBC_UNINSTALL_COLL_API(comm, nbc_module, ineighbor_allgatherv); + NBC_UNINSTALL_COLL_API(comm, nbc_module, ineighbor_alltoall); + NBC_UNINSTALL_COLL_API(comm, nbc_module, ineighbor_alltoallv); + NBC_UNINSTALL_COLL_API(comm, nbc_module, ineighbor_alltoallw); + + NBC_UNINSTALL_COLL_API(comm, 
nbc_module, neighbor_allgather_init); + NBC_UNINSTALL_COLL_API(comm, nbc_module, neighbor_allgatherv_init); + NBC_UNINSTALL_COLL_API(comm, nbc_module, neighbor_alltoall_init); + NBC_UNINSTALL_COLL_API(comm, nbc_module, neighbor_alltoallv_init); + NBC_UNINSTALL_COLL_API(comm, nbc_module, neighbor_alltoallw_init); + } /* All done */ + return OMPI_SUCCESS; +} int ompi_coll_libnbc_progress(void) diff --git a/ompi/mca/coll/monitoring/coll_monitoring.h b/ompi/mca/coll/monitoring/coll_monitoring.h index 3f741970907..49d9d10b865 100644 --- a/ompi/mca/coll/monitoring/coll_monitoring.h +++ b/ompi/mca/coll/monitoring/coll_monitoring.h @@ -4,6 +4,7 @@ * and Technology (RIST). All rights reserved. * Copyright (c) 2017 Amazon.com, Inc. or its affiliates. All Rights * reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -36,7 +37,6 @@ struct mca_coll_monitoring_module_t { mca_coll_base_module_t super; mca_coll_base_comm_coll_t real; mca_monitoring_coll_data_t*data; - opal_atomic_int32_t is_initialized; }; typedef struct mca_coll_monitoring_module_t mca_coll_monitoring_module_t; OMPI_DECLSPEC OBJ_CLASS_DECLARATION(mca_coll_monitoring_module_t); diff --git a/ompi/mca/coll/monitoring/coll_monitoring_component.c b/ompi/mca/coll/monitoring/coll_monitoring_component.c index b23c3308ca2..cc1f56293d6 100644 --- a/ompi/mca/coll/monitoring/coll_monitoring_component.c +++ b/ompi/mca/coll/monitoring/coll_monitoring_component.c @@ -7,6 +7,7 @@ * reserved. * Copyright (c) 2019 Research Organization for Information Science * and Technology (RIST). All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -21,49 +22,43 @@ #include "ompi/mca/coll/coll.h" #include "opal/mca/base/mca_base_component_repository.h" -#define MONITORING_SAVE_PREV_COLL_API(__module, __comm, __api) \ - do { \ - if( NULL != __comm->c_coll->coll_ ## __api ## _module ) { \ - __module->real.coll_ ## __api = __comm->c_coll->coll_ ## __api; \ - __module->real.coll_ ## __api ## _module = __comm->c_coll->coll_ ## __api ## _module; \ - OBJ_RETAIN(__module->real.coll_ ## __api ## _module); \ - } else { \ - /* If no function previously provided, do not monitor */ \ - __module->super.coll_ ## __api = NULL; \ - OPAL_MONITORING_PRINT_WARN("COMM \"%s\": No monitoring available for " \ - "coll_" # __api, __comm->c_name); \ - } \ - if( NULL != __comm->c_coll->coll_i ## __api ## _module ) { \ - __module->real.coll_i ## __api = __comm->c_coll->coll_i ## __api; \ - __module->real.coll_i ## __api ## _module = __comm->c_coll->coll_i ## __api ## _module; \ - OBJ_RETAIN(__module->real.coll_i ## __api ## _module); \ - } else { \ - /* If no function previously provided, do not monitor */ \ - __module->super.coll_i ## __api = NULL; \ - OPAL_MONITORING_PRINT_WARN("COMM \"%s\": No monitoring available for " \ - "coll_i" # __api, __comm->c_name); \ - } \ - } while(0) - -#define MONITORING_RELEASE_PREV_COLL_API(__module, __comm, __api) \ - do { \ - if( NULL != __module->real.coll_ ## __api ## _module ) { \ - if( NULL != __module->real.coll_ ## __api ## _module->coll_module_disable ) { \ - __module->real.coll_ ## __api ## _module->coll_module_disable(__module->real.coll_ ## __api ## _module, __comm); \ - } \ - OBJ_RELEASE(__module->real.coll_ ## __api ## _module); \ - __module->real.coll_ ## __api = NULL; \ - __module->real.coll_ ## __api ## _module = NULL; \ - } \ - if( NULL != __module->real.coll_i ## __api ## _module ) { 
\ - if( NULL != __module->real.coll_i ## __api ## _module->coll_module_disable ) { \ - __module->real.coll_i ## __api ## _module->coll_module_disable(__module->real.coll_i ## __api ## _module, __comm); \ - } \ - OBJ_RELEASE(__module->real.coll_i ## __api ## _module); \ - __module->real.coll_i ## __api = NULL; \ - __module->real.coll_i ## __api ## _module = NULL; \ - } \ - } while(0) +#define MONITORING_INSTALL_COLL_API(__module, __comm, __api) \ + do \ + { \ + if ((NULL != __comm->c_coll->coll_##__api##_module) && \ + (NULL != __module->super.coll_##__api)) \ + { \ + /* save the current selected collective */ \ + MCA_COLL_SAVE_API(__comm, __api, __module->real.coll_##__api, __module->real.coll_##__api##_module, "monitoring"); \ + /* install our own */ \ + MCA_COLL_INSTALL_API(__comm, __api, __module->super.coll_##__api, &__module->super, "monitoring"); \ + } \ + if ((NULL != __comm->c_coll->coll_i##__api##_module) && \ + (NULL != __module->super.coll_i##__api)) \ + { \ + /* save the current selected collective */ \ + MCA_COLL_SAVE_API(__comm, i##__api, __module->real.coll_i##__api, __module->real.coll_i##__api##_module, "monitoring"); \ + /* install our own */ \ + MCA_COLL_INSTALL_API(__comm, i##__api, __module->super.coll_i##__api, &__module->super, "monitoring"); \ + } \ + } while (0) + +#define MONITORING_UNINSTALL_COLL_API(__module, __comm, __api) \ + do \ + { \ + if (&__module->super == __comm->c_coll->coll_##__api##_module) \ + { \ + MCA_COLL_INSTALL_API(__comm, __api, __module->real.coll_##__api, __module->real.coll_##__api##_module, "monitoring"); \ + __module->real.coll_##__api = NULL; \ + __module->real.coll_##__api##_module = NULL; \ + } \ + if (&__module->super == __comm->c_coll->coll_i##__api##_module) \ + { \ + MCA_COLL_INSTALL_API(__comm, i##__api, __module->real.coll_i##__api, __module->real.coll_i##__api##_module, "monitoring"); \ + __module->real.coll_i##__api = NULL; \ + __module->real.coll_i##__api##_module = NULL; \ + } \ + } while (0) #define MONITORING_SET_FULL_PREV_COLL_API(m, c, operation) \ do { \ @@ -91,11 +86,11 @@ operation(m, c, neighbor_alltoallw); \ } while(0) -#define MONITORING_SAVE_FULL_PREV_COLL_API(m, c) \ - MONITORING_SET_FULL_PREV_COLL_API((m), (c), MONITORING_SAVE_PREV_COLL_API) +#define MONITORING_SAVE_FULL_PREV_COLL_API(m, c) \ + MONITORING_SET_FULL_PREV_COLL_API((m), (c), MONITORING_INSTALL_COLL_API) -#define MONITORING_RELEASE_FULL_PREV_COLL_API(m, c) \ - MONITORING_SET_FULL_PREV_COLL_API((m), (c), MONITORING_RELEASE_PREV_COLL_API) +#define MONITORING_RELEASE_FULL_PREV_COLL_API(m, c) \ + MONITORING_SET_FULL_PREV_COLL_API((m), (c), MONITORING_UNINSTALL_COLL_API) static int mca_coll_monitoring_component_open(void) { @@ -125,11 +120,11 @@ static int mca_coll_monitoring_module_enable(mca_coll_base_module_t*module, struct ompi_communicator_t*comm) { mca_coll_monitoring_module_t*monitoring_module = (mca_coll_monitoring_module_t*) module; - if( 1 == opal_atomic_add_fetch_32(&monitoring_module->is_initialized, 1) ) { - MONITORING_SAVE_FULL_PREV_COLL_API(monitoring_module, comm); - monitoring_module->data = mca_common_monitoring_coll_new(comm); - OPAL_MONITORING_PRINT_INFO("coll_module_enabled"); - } + + MONITORING_SAVE_FULL_PREV_COLL_API(monitoring_module, comm); + monitoring_module->data = mca_common_monitoring_coll_new(comm); + OPAL_MONITORING_PRINT_INFO("coll_module_enabled"); + return OMPI_SUCCESS; } @@ -137,12 +132,12 @@ static int mca_coll_monitoring_module_disable(mca_coll_base_module_t*module, struct ompi_communicator_t*comm) { 
mca_coll_monitoring_module_t*monitoring_module = (mca_coll_monitoring_module_t*) module; - if( 0 == opal_atomic_sub_fetch_32(&monitoring_module->is_initialized, 1) ) { - MONITORING_RELEASE_FULL_PREV_COLL_API(monitoring_module, comm); - mca_common_monitoring_coll_release(monitoring_module->data); - monitoring_module->data = NULL; - OPAL_MONITORING_PRINT_INFO("coll_module_disabled"); - } + + MONITORING_RELEASE_FULL_PREV_COLL_API(monitoring_module, comm); + mca_common_monitoring_coll_release(monitoring_module->data); + monitoring_module->data = NULL; + OPAL_MONITORING_PRINT_INFO("coll_module_disabled"); + return OMPI_SUCCESS; } @@ -208,9 +203,6 @@ mca_coll_monitoring_component_query(struct ompi_communicator_t*comm, int*priorit monitoring_module->super.coll_ineighbor_alltoallv = mca_coll_monitoring_ineighbor_alltoallv; monitoring_module->super.coll_ineighbor_alltoallw = mca_coll_monitoring_ineighbor_alltoallw; - /* Initialization flag */ - monitoring_module->is_initialized = 0; - *priority = mca_coll_monitoring_component.priority; return &(monitoring_module->super); diff --git a/ompi/mca/coll/portals4/coll_portals4_component.c b/ompi/mca/coll/portals4/coll_portals4_component.c index 6b0a8573371..11b476ceb20 100644 --- a/ompi/mca/coll/portals4/coll_portals4_component.c +++ b/ompi/mca/coll/portals4/coll_portals4_component.c @@ -14,6 +14,7 @@ * Copyright (c) 2015 Los Alamos National Security, LLC. All rights * reserved. * Copyright (c) 2015 Bull SAS. All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -127,20 +128,37 @@ ptl_datatype_t ompi_coll_portals4_atomic_datatype [OMPI_DATATYPE_MPI_MAX_PREDEFI }; - -#define PORTALS4_SAVE_PREV_COLL_API(__module, __comm, __api) \ - do { \ - __module->previous_ ## __api = __comm->c_coll->coll_ ## __api; \ - __module->previous_ ## __api ## _module = __comm->c_coll->coll_ ## __api ## _module; \ - if (!comm->c_coll->coll_ ## __api || !comm->c_coll->coll_ ## __api ## _module) { \ - opal_output_verbose(1, ompi_coll_base_framework.framework_output, \ - "(%d/%s): no underlying " # __api"; disqualifying myself", \ - ompi_comm_get_local_cid(__comm), __comm->c_name); \ - return OMPI_ERROR; \ - } \ - OBJ_RETAIN(__module->previous_ ## __api ## _module); \ - } while(0) - +#define PORTALS4_INSTALL_COLL_API(__module, __comm, __api) \ + do \ + { \ + if (!comm->c_coll->coll_##__api || !comm->c_coll->coll_##__api##_module) \ + { \ + opal_output_verbose(1, ompi_coll_base_framework.framework_output, \ + "(%d/%s): no underlying " #__api "; disqualifying myself", \ + ompi_comm_get_local_cid(__comm), __comm->c_name); \ + __module->previous_##__api = NULL; \ + __module->previous_##__api##_module = NULL; \ + } \ + else \ + { \ + /* save the current selected collective */ \ + MCA_COLL_SAVE_API(__comm, __api, __module->previous_##__api, __module->previous_##__api##_module, "portals"); \ + /* install our own */ \ + MCA_COLL_INSTALL_API(__comm, __api, __module->super.coll_##__api, &__module->super, "portals"); \ + } \ + } while (0) + +#define PORTALS4_UNINSTALL_COLL_API(__module, __comm, __api) \ + do \ + { \ + if ((&__module->super == comm->c_coll->coll_##__api##_module) && \ + (NULL != __module->previous_##__api##_module)) \ + { \ + MCA_COLL_INSTALL_API(__comm, __api, __module->previous_##__api, __module->previous_##__api##_module, "portals"); \ + __module->previous_##__api = NULL; \ + __module->previous_##__api##_module = NULL; \ + } \ + } while (0) const char *mca_coll_portals4_component_version_string = "Open
MPI Portals 4 collective MCA component version " OMPI_VERSION; @@ -158,6 +176,8 @@ static mca_coll_base_module_t* portals4_comm_query(struct ompi_communicator_t *c int *priority); static int portals4_module_enable(mca_coll_base_module_t *module, struct ompi_communicator_t *comm); +static int portals4_module_disable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm); static int portals4_progress(void); @@ -618,6 +638,7 @@ portals4_comm_query(struct ompi_communicator_t *comm, *priority = mca_coll_portals4_priority; portals4_module->coll_count = 0; portals4_module->super.coll_module_enable = portals4_module_enable; + portals4_module->super.coll_module_disable = portals4_module_disable; portals4_module->super.coll_barrier = ompi_coll_portals4_barrier_intra; portals4_module->super.coll_ibarrier = ompi_coll_portals4_ibarrier_intra; @@ -653,14 +674,45 @@ portals4_module_enable(mca_coll_base_module_t *module, { mca_coll_portals4_module_t *portals4_module = (mca_coll_portals4_module_t*) module; - PORTALS4_SAVE_PREV_COLL_API(portals4_module, comm, allreduce); - PORTALS4_SAVE_PREV_COLL_API(portals4_module, comm, iallreduce); - PORTALS4_SAVE_PREV_COLL_API(portals4_module, comm, reduce); - PORTALS4_SAVE_PREV_COLL_API(portals4_module, comm, ireduce); + PORTALS4_INSTALL_COLL_API(portals4_module, comm, iallreduce); + PORTALS4_INSTALL_COLL_API(portals4_module, comm, allreduce); + PORTALS4_INSTALL_COLL_API(portals4_module, comm, ireduce); + PORTALS4_INSTALL_COLL_API(portals4_module, comm, reduce); + + PORTALS4_INSTALL_COLL_API(portals4_module, comm, barrier); + PORTALS4_INSTALL_COLL_API(portals4_module, comm, ibarrier); + PORTALS4_INSTALL_COLL_API(portals4_module, comm, gather); + PORTALS4_INSTALL_COLL_API(portals4_module, comm, igather); + PORTALS4_INSTALL_COLL_API(portals4_module, comm, scatter); + PORTALS4_INSTALL_COLL_API(portals4_module, comm, iscatter); + PORTALS4_INSTALL_COLL_API(portals4_module, comm, bcast); + PORTALS4_INSTALL_COLL_API(portals4_module, comm, ibcast); return OMPI_SUCCESS; } +static int +portals4_module_disable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm) +{ + mca_coll_portals4_module_t *portals4_module = (mca_coll_portals4_module_t *)module; + + PORTALS4_UNINSTALL_COLL_API(portals4_module, comm, allreduce); + PORTALS4_UNINSTALL_COLL_API(portals4_module, comm, iallreduce); + PORTALS4_UNINSTALL_COLL_API(portals4_module, comm, reduce); + PORTALS4_UNINSTALL_COLL_API(portals4_module, comm, ireduce); + + PORTALS4_UNINSTALL_COLL_API(portals4_module, comm, barrier); + PORTALS4_UNINSTALL_COLL_API(portals4_module, comm, ibarrier); + PORTALS4_UNINSTALL_COLL_API(portals4_module, comm, gather); + PORTALS4_UNINSTALL_COLL_API(portals4_module, comm, igather); + PORTALS4_UNINSTALL_COLL_API(portals4_module, comm, scatter); + PORTALS4_UNINSTALL_COLL_API(portals4_module, comm, iscatter); + PORTALS4_UNINSTALL_COLL_API(portals4_module, comm, bcast); + PORTALS4_UNINSTALL_COLL_API(portals4_module, comm, ibcast); + + return OMPI_SUCCESS; +} #if OPAL_ENABLE_DEBUG /* These string maps are only used for debugging output. * They will be compiled-out when OPAL is configured diff --git a/ompi/mca/coll/self/coll_self.h b/ompi/mca/coll/self/coll_self.h index 1dad4536d67..1585e3bc8a1 100644 --- a/ompi/mca/coll/self/coll_self.h +++ b/ompi/mca/coll/self/coll_self.h @@ -12,6 +12,7 @@ * Copyright (c) 2008 Cisco Systems, Inc. All rights reserved. * Copyright (c) 2015 Research Organization for Information Science * and Technology (RIST). All rights reserved. 
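In contrast to self and inter, the Portals 4 macros above refuse to take over a collective when the communicator has no previously selected implementation, since the portals4 code needs a fallback to delegate to. A condensed model of that decision, using placeholder types and names:

#include <stdbool.h>
#include <stddef.h>

typedef int (*coll_fn_t)(void);
struct module;

/* mirrors the "no underlying X; disqualifying myself" decision above:
 * only take over when there is a previous implementation to fall back on */
static bool may_take_over(coll_fn_t current_fn, struct module *current_module,
                          coll_fn_t *saved_fn, struct module **saved_module)
{
    if (NULL == current_fn || NULL == current_module) {
        *saved_fn = NULL;              /* nothing usable underneath: stay out  */
        *saved_module = NULL;
        return false;
    }
    *saved_fn = current_fn;            /* keep the previous pair for disable   */
    *saved_module = current_module;
    return true;
}

int main(void)
{
    coll_fn_t prev_fn;
    struct module *prev_module;

    /* no reduce selected yet on this communicator: the component leaves the slot alone */
    (void) may_take_over(NULL, NULL, &prev_fn, &prev_module);
    return 0;
}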
+ * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -50,9 +51,6 @@ int mca_coll_self_init_query(bool enable_progress_threads, mca_coll_base_module_t * mca_coll_self_comm_query(struct ompi_communicator_t *comm, int *priority); -int mca_coll_self_module_enable(mca_coll_base_module_t *module, - struct ompi_communicator_t *comm); - int mca_coll_self_allgather_intra(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void *rbuf, int rcount, diff --git a/ompi/mca/coll/self/coll_self_module.c b/ompi/mca/coll/self/coll_self_module.c index d782b998165..ec706b52613 100644 --- a/ompi/mca/coll/self/coll_self_module.c +++ b/ompi/mca/coll/self/coll_self_module.c @@ -10,6 +10,7 @@ * Copyright (c) 2004-2005 The Regents of the University of California. * All rights reserved. * Copyright (c) 2017 IBM Corporation. All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -29,6 +30,12 @@ #include "ompi/mca/coll/base/coll_base_functions.h" #include "coll_self.h" +static int +mca_coll_self_module_enable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm); +static int +mca_coll_self_module_disable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm); /* * Initial query function that is invoked during MPI_INIT, allowing @@ -63,22 +70,7 @@ mca_coll_self_comm_query(struct ompi_communicator_t *comm, if (NULL == module) return NULL; module->super.coll_module_enable = mca_coll_self_module_enable; - module->super.coll_allgather = mca_coll_self_allgather_intra; - module->super.coll_allgatherv = mca_coll_self_allgatherv_intra; - module->super.coll_allreduce = mca_coll_self_allreduce_intra; - module->super.coll_alltoall = mca_coll_self_alltoall_intra; - module->super.coll_alltoallv = mca_coll_self_alltoallv_intra; - module->super.coll_alltoallw = mca_coll_self_alltoallw_intra; - module->super.coll_barrier = mca_coll_self_barrier_intra; - module->super.coll_bcast = mca_coll_self_bcast_intra; - module->super.coll_exscan = mca_coll_self_exscan_intra; - module->super.coll_gather = mca_coll_self_gather_intra; - module->super.coll_gatherv = mca_coll_self_gatherv_intra; - module->super.coll_reduce = mca_coll_self_reduce_intra; - module->super.coll_reduce_scatter = mca_coll_self_reduce_scatter_intra; - module->super.coll_scan = mca_coll_self_scan_intra; - module->super.coll_scatter = mca_coll_self_scatter_intra; - module->super.coll_scatterv = mca_coll_self_scatterv_intra; + module->super.coll_module_disable = mca_coll_self_module_disable; module->super.coll_reduce_local = mca_coll_base_reduce_local; @@ -88,13 +80,73 @@ mca_coll_self_comm_query(struct ompi_communicator_t *comm, return NULL; } +#define SELF_INSTALL_COLL_API(__comm, __module, __api) \ + do \ + { \ + MCA_COLL_INSTALL_API(__comm, __api, mca_coll_self_##__api##_intra, &__module->super, "self"); \ + } while (0) + +#define SELF_UNINSTALL_COLL_API(__comm, __module, __api) \ + do \ + { \ + if (__comm->c_coll->coll_##__api##_module == &__module->super) \ + { \ + MCA_COLL_INSTALL_API(__comm, __api, NULL, NULL, "self"); \ + } \ + } while (0) /* * Init module on the communicator */ -int +static int mca_coll_self_module_enable(mca_coll_base_module_t *module, struct ompi_communicator_t *comm) { + mca_coll_self_module_t *sm_module = (mca_coll_self_module_t*)module; + + SELF_INSTALL_COLL_API(comm, sm_module, allgather); + SELF_INSTALL_COLL_API(comm, sm_module, allgatherv); + 
SELF_INSTALL_COLL_API(comm, sm_module, allreduce); + SELF_INSTALL_COLL_API(comm, sm_module, alltoall); + SELF_INSTALL_COLL_API(comm, sm_module, alltoallv); + SELF_INSTALL_COLL_API(comm, sm_module, alltoallw); + SELF_INSTALL_COLL_API(comm, sm_module, barrier); + SELF_INSTALL_COLL_API(comm, sm_module, bcast); + SELF_INSTALL_COLL_API(comm, sm_module, exscan); + SELF_INSTALL_COLL_API(comm, sm_module, gather); + SELF_INSTALL_COLL_API(comm, sm_module, gatherv); + SELF_INSTALL_COLL_API(comm, sm_module, reduce); + SELF_INSTALL_COLL_API(comm, sm_module, reduce_scatter); + SELF_INSTALL_COLL_API(comm, sm_module, scan); + SELF_INSTALL_COLL_API(comm, sm_module, scatter); + SELF_INSTALL_COLL_API(comm, sm_module, scatterv); + + MCA_COLL_INSTALL_API(comm, reduce_local, mca_coll_base_reduce_local, module, "self"); + return OMPI_SUCCESS; +} + +static int +mca_coll_self_module_disable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm) +{ + mca_coll_self_module_t *sm_module = (mca_coll_self_module_t *)module; + + SELF_UNINSTALL_COLL_API(comm, sm_module, allgather); + SELF_UNINSTALL_COLL_API(comm, sm_module, allgatherv); + SELF_UNINSTALL_COLL_API(comm, sm_module, allreduce); + SELF_UNINSTALL_COLL_API(comm, sm_module, alltoall); + SELF_UNINSTALL_COLL_API(comm, sm_module, alltoallv); + SELF_UNINSTALL_COLL_API(comm, sm_module, alltoallw); + SELF_UNINSTALL_COLL_API(comm, sm_module, barrier); + SELF_UNINSTALL_COLL_API(comm, sm_module, bcast); + SELF_UNINSTALL_COLL_API(comm, sm_module, exscan); + SELF_UNINSTALL_COLL_API(comm, sm_module, gather); + SELF_UNINSTALL_COLL_API(comm, sm_module, gatherv); + SELF_UNINSTALL_COLL_API(comm, sm_module, reduce); + SELF_UNINSTALL_COLL_API(comm, sm_module, reduce_scatter); + SELF_UNINSTALL_COLL_API(comm, sm_module, scan); + SELF_UNINSTALL_COLL_API(comm, sm_module, scatter); + SELF_UNINSTALL_COLL_API(comm, sm_module, scatterv); + return OMPI_SUCCESS; } diff --git a/ompi/mca/coll/sync/Makefile.am b/ompi/mca/coll/sync/Makefile.am index 2f75cd2dfa5..dad4f8f7138 100644 --- a/ompi/mca/coll/sync/Makefile.am +++ b/ompi/mca/coll/sync/Makefile.am @@ -12,6 +12,7 @@ # Copyright (c) 2009 Cisco Systems, Inc. All rights reserved. # Copyright (c) 2016 Intel, Inc. All rights reserved # Copyright (c) 2017 IBM Corporation. All rights reserved. +# Copyright (c) 2024 NVIDIA Corporation. All rights reserved. # $COPYRIGHT$ # # Additional copyrights may follow @@ -19,8 +20,6 @@ # $HEADER$ # -dist_ompidata_DATA = help-coll-sync.txt - sources = \ coll_sync.h \ coll_sync_component.c \ diff --git a/ompi/mca/coll/sync/coll_sync.h b/ompi/mca/coll/sync/coll_sync.h index 76913b9615c..b6617f9ef85 100644 --- a/ompi/mca/coll/sync/coll_sync.h +++ b/ompi/mca/coll/sync/coll_sync.h @@ -10,6 +10,7 @@ * Copyright (c) 2004-2005 The Regents of the University of California. * All rights reserved. * Copyright (c) 2008-2009 Cisco Systems, Inc. All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. 
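The sync component below wraps whatever was selected before it, so the save/install pair effectively stacks modules per collective: enable remembers the previous function, the wrapper delegates to it, and disable puts it back. A toy model of that stacking, independent of the real OMPI types (the wrapper body here only stands in for the extra synchronization the sync component adds; its actual barrier policy is outside this diff):

#include <stdio.h>

typedef void (*bcast_fn_t)(void);

static bcast_fn_t current_bcast;   /* the communicator's currently selected bcast */
static bcast_fn_t saved_bcast;     /* what the wrapper remembered at enable time  */

static void tuned_bcast(void) { puts("tuned bcast"); }

static void sync_bcast(void)
{
    saved_bcast();                  /* delegate to whatever was selected before   */
    puts("synchronize");            /* plus the extra step the wrapper adds       */
}

int main(void)
{
    current_bcast = tuned_bcast;    /* e.g. tuned was enabled first               */

    saved_bcast   = current_bcast;  /* sync enable: save the previous pair ...    */
    current_bcast = sync_bcast;     /* ... and install the wrapper                */
    current_bcast();

    current_bcast = saved_bcast;    /* sync disable: restore what was saved       */
    current_bcast();
    return 0;
}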
* $COPYRIGHT$ * * Additional copyrights may follow @@ -43,9 +44,6 @@ mca_coll_base_module_t *mca_coll_sync_comm_query(struct ompi_communicator_t *comm, int *priority); -int mca_coll_sync_module_enable(mca_coll_base_module_t *module, - struct ompi_communicator_t *comm); - int mca_coll_sync_barrier(struct ompi_communicator_t *comm, mca_coll_base_module_t *module); diff --git a/ompi/mca/coll/sync/coll_sync_module.c b/ompi/mca/coll/sync/coll_sync_module.c index 128cb8c3fa9..65b32f1ba16 100644 --- a/ompi/mca/coll/sync/coll_sync_module.c +++ b/ompi/mca/coll/sync/coll_sync_module.c @@ -13,6 +13,7 @@ * Copyright (c) 2016 Research Organization for Information Science * and Technology (RIST). All rights reserved. * Copyright (c) 2018-2019 Intel, Inc. All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -40,6 +41,12 @@ #include "ompi/mca/coll/base/base.h" #include "coll_sync.h" +static int +mca_coll_sync_module_enable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm); +static int +mca_coll_sync_module_disable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm); static void mca_coll_sync_module_construct(mca_coll_sync_module_t *module) { @@ -111,16 +118,10 @@ mca_coll_sync_comm_query(struct ompi_communicator_t *comm, /* Choose whether to use [intra|inter] */ sync_module->super.coll_module_enable = mca_coll_sync_module_enable; + sync_module->super.coll_module_disable = mca_coll_sync_module_disable; /* The "all" versions are already synchronous. So no need for an additional barrier there. */ - sync_module->super.coll_allgather = NULL; - sync_module->super.coll_allgatherv = NULL; - sync_module->super.coll_allreduce = NULL; - sync_module->super.coll_alltoall = NULL; - sync_module->super.coll_alltoallv = NULL; - sync_module->super.coll_alltoallw = NULL; - sync_module->super.coll_barrier = NULL; sync_module->super.coll_bcast = mca_coll_sync_bcast; sync_module->super.coll_exscan = mca_coll_sync_exscan; sync_module->super.coll_gather = mca_coll_sync_gather; @@ -134,47 +135,78 @@ mca_coll_sync_comm_query(struct ompi_communicator_t *comm, return &(sync_module->super); } +#define SYNC_INSTALL_COLL_API(__comm, __module, __api) \ + do \ + { \ + /* save the current selected collective */ \ + MCA_COLL_SAVE_API(__comm, __api, __module->c_coll.coll_##__api, __module->c_coll.coll_##__api##_module, "sync"); \ + /* install our own */ \ + MCA_COLL_INSTALL_API(__comm, __api, mca_coll_sync_##__api, &__module->super, "sync"); \ + } while (0) + +#define SYNC_UNINSTALL_COLL_API(__comm, __module, __api) \ + do \ + { \ + if (__comm->c_coll->coll_##__api##_module == &__module->super) \ + { \ + MCA_COLL_INSTALL_API(__comm, __api, __module->c_coll.coll_##__api, __module->c_coll.coll_##__api##_module, "sync"); \ + __module->c_coll.coll_##__api = NULL; \ + __module->c_coll.coll_##__api##_module = NULL; \ + } \ + } while (0) /* * Init module on the communicator */ -int mca_coll_sync_module_enable(mca_coll_base_module_t *module, - struct ompi_communicator_t *comm) +static int +mca_coll_sync_module_enable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm) { - bool good = true; - char *msg = NULL; - mca_coll_sync_module_t *s = (mca_coll_sync_module_t*) module; + mca_coll_sync_module_t *sync_module = (mca_coll_sync_module_t*) module; - /* Save the prior layer of coll functions */ - s->c_coll = *comm->c_coll; - -#define CHECK_AND_RETAIN(name) \ - if (NULL == s->c_coll.coll_ ## name ## _module) { \ - good = false;
\ - msg = #name; \ - } else if (good) { \ - OBJ_RETAIN(s->c_coll.coll_ ## name ## _module); \ - } + /* Save the prior layer of coll functions */ + sync_module->c_coll = *comm->c_coll; + + /* The "all" versions are already synchronous. So no need for an + additional barrier there. */ + SYNC_INSTALL_COLL_API(comm, sync_module, bcast); + SYNC_INSTALL_COLL_API(comm, sync_module, gather); + SYNC_INSTALL_COLL_API(comm, sync_module, gatherv); + SYNC_INSTALL_COLL_API(comm, sync_module, reduce); + SYNC_INSTALL_COLL_API(comm, sync_module, reduce_scatter); + SYNC_INSTALL_COLL_API(comm, sync_module, scatter); + SYNC_INSTALL_COLL_API(comm, sync_module, scatterv); - CHECK_AND_RETAIN(bcast); - CHECK_AND_RETAIN(gather); - CHECK_AND_RETAIN(gatherv); - CHECK_AND_RETAIN(reduce); - CHECK_AND_RETAIN(reduce_scatter); - CHECK_AND_RETAIN(scatter); - CHECK_AND_RETAIN(scatterv); if (!OMPI_COMM_IS_INTER(comm)) { /* MPI does not define scan/exscan on intercommunicators */ - CHECK_AND_RETAIN(exscan); - CHECK_AND_RETAIN(scan); + SYNC_INSTALL_COLL_API(comm, sync_module, scan); + SYNC_INSTALL_COLL_API(comm, sync_module, exscan); } - /* All done */ - if (good) { - return OMPI_SUCCESS; + return OMPI_SUCCESS; +} + +static int +mca_coll_sync_module_disable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm) +{ + mca_coll_sync_module_t *sync_module = (mca_coll_sync_module_t*) module; + + /* The "all" versions are already synchronous. So no need for an + additional barrier there. */ + SYNC_UNINSTALL_COLL_API(comm, sync_module, bcast); + SYNC_UNINSTALL_COLL_API(comm, sync_module, gather); + SYNC_UNINSTALL_COLL_API(comm, sync_module, gatherv); + SYNC_UNINSTALL_COLL_API(comm, sync_module, reduce); + SYNC_UNINSTALL_COLL_API(comm, sync_module, reduce_scatter); + SYNC_UNINSTALL_COLL_API(comm, sync_module, scatter); + SYNC_UNINSTALL_COLL_API(comm, sync_module, scatterv); + + if (!OMPI_COMM_IS_INTER(comm)) { + /* MPI does not define scan/exscan on intercommunicators */ + SYNC_UNINSTALL_COLL_API(comm, sync_module, scan); + SYNC_UNINSTALL_COLL_API(comm, sync_module, exscan); } - opal_show_help("help-coll-sync.txt", "missing collective", true, - ompi_process_info.nodename, - mca_coll_sync_component.priority, msg); - return OMPI_ERR_NOT_FOUND; + + return OMPI_SUCCESS; } diff --git a/ompi/mca/coll/sync/help-coll-sync.txt b/ompi/mca/coll/sync/help-coll-sync.txt deleted file mode 100644 index 4a5c871207e..00000000000 --- a/ompi/mca/coll/sync/help-coll-sync.txt +++ /dev/null @@ -1,22 +0,0 @@ -# -*- text -*- -# -# Copyright (c) 2009 Cisco Systems, Inc. All rights reserved. -# $COPYRIGHT$ -# -# Additional copyrights may follow -# -# $HEADER$ -# -# This is the US/English general help file for Open MPI's sync -# collective component. -# -[missing collective] -The sync collective component in Open MPI was activated on a -communicator where it did not find an underlying collective operation -defined. This usually means that the sync collective module's -priority was not set high enough. Please try increasing sync's -priority. - - Local host: %s - Sync coll module priority: %d - First discovered missing collective: %s diff --git a/ompi/mca/coll/tuned/coll_tuned_module.c b/ompi/mca/coll/tuned/coll_tuned_module.c index f1be43440cc..f83b3ecd9ea 100644 --- a/ompi/mca/coll/tuned/coll_tuned_module.c +++ b/ompi/mca/coll/tuned/coll_tuned_module.c @@ -13,6 +13,7 @@ * Copyright (c) 2016 Intel, Inc. All rights reserved. * Copyright (c) 2018 Research Organization for Information Science * and Technology (RIST). All rights reserved.
+ * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -36,6 +37,8 @@ static int tuned_module_enable(mca_coll_base_module_t *module, struct ompi_communicator_t *comm); +static int tuned_module_disable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm); /* * Initial query function that is invoked during MPI_INIT, allowing * this component to disqualify itself if it doesn't support the @@ -89,6 +92,7 @@ ompi_coll_tuned_comm_query(struct ompi_communicator_t *comm, int *priority) * but this would probably add an extra if and funct call to the path */ tuned_module->super.coll_module_enable = tuned_module_enable; + tuned_module->super.coll_module_disable = tuned_module_disable; /* By default stick with the fixed version of the tuned collectives. Later on, * when the module get enabled, set the correct version based on the availability @@ -99,18 +103,13 @@ ompi_coll_tuned_comm_query(struct ompi_communicator_t *comm, int *priority) tuned_module->super.coll_allreduce = ompi_coll_tuned_allreduce_intra_dec_fixed; tuned_module->super.coll_alltoall = ompi_coll_tuned_alltoall_intra_dec_fixed; tuned_module->super.coll_alltoallv = ompi_coll_tuned_alltoallv_intra_dec_fixed; - tuned_module->super.coll_alltoallw = NULL; tuned_module->super.coll_barrier = ompi_coll_tuned_barrier_intra_dec_fixed; tuned_module->super.coll_bcast = ompi_coll_tuned_bcast_intra_dec_fixed; - tuned_module->super.coll_exscan = NULL; tuned_module->super.coll_gather = ompi_coll_tuned_gather_intra_dec_fixed; - tuned_module->super.coll_gatherv = NULL; tuned_module->super.coll_reduce = ompi_coll_tuned_reduce_intra_dec_fixed; tuned_module->super.coll_reduce_scatter = ompi_coll_tuned_reduce_scatter_intra_dec_fixed; tuned_module->super.coll_reduce_scatter_block = ompi_coll_tuned_reduce_scatter_block_intra_dec_fixed; - tuned_module->super.coll_scan = NULL; tuned_module->super.coll_scatter = ompi_coll_tuned_scatter_intra_dec_fixed; - tuned_module->super.coll_scatterv = NULL; return &(tuned_module->super); } @@ -148,8 +147,26 @@ ompi_coll_tuned_forced_getvalues( enum COLLTYPE type, return (MPI_SUCCESS); } -#define COLL_TUNED_EXECUTE_IF_DYNAMIC(TMOD, TYPE, EXECUTE) \ +#define TUNED_INSTALL_COLL_API(__comm, __module, __api) \ + do \ + { \ + if (__module->super.coll_##__api) \ + { \ + MCA_COLL_INSTALL_API(__comm, __api, __module->super.coll_##__api, &__module->super, "tuned"); \ + } \ + } while (0) + +#define TUNED_UNINSTALL_COLL_API(__comm, __module, __api) \ + do \ { \ + if (__comm->c_coll->coll_##__api##_module == &__module->super) \ + { \ + MCA_COLL_INSTALL_API(__comm, __api, NULL, NULL, "tuned"); \ + } \ + } while (0) + +#define COLL_TUNED_EXECUTE_IF_DYNAMIC(TMOD, TYPE, EXECUTE) \ + do { \ int need_dynamic_decision = 0; \ ompi_coll_tuned_forced_getvalues( (TYPE), &((TMOD)->user_forced[(TYPE)]) ); \ (TMOD)->com_rules[(TYPE)] = NULL; \ @@ -168,7 +185,7 @@ ompi_coll_tuned_forced_getvalues( enum COLLTYPE type, OPAL_OUTPUT((ompi_coll_tuned_stream,"coll:tuned: enable dynamic selection for "#TYPE)); \ EXECUTE; \ } \ - } + } while(0) /* * Init module on the communicator @@ -249,6 +266,23 @@ tuned_module_enable( mca_coll_base_module_t *module, COLL_TUNED_EXECUTE_IF_DYNAMIC(tuned_module, SCATTERV, tuned_module->super.coll_scatterv = NULL); } + TUNED_INSTALL_COLL_API(comm, tuned_module, allgather); + TUNED_INSTALL_COLL_API(comm, tuned_module, allgatherv); + TUNED_INSTALL_COLL_API(comm, tuned_module, allreduce); + TUNED_INSTALL_COLL_API(comm, tuned_module, alltoall); + 
TUNED_INSTALL_COLL_API(comm, tuned_module, alltoallv); + TUNED_INSTALL_COLL_API(comm, tuned_module, alltoallw); + TUNED_INSTALL_COLL_API(comm, tuned_module, barrier); + TUNED_INSTALL_COLL_API(comm, tuned_module, bcast); + TUNED_INSTALL_COLL_API(comm, tuned_module, exscan); + TUNED_INSTALL_COLL_API(comm, tuned_module, gather); + TUNED_INSTALL_COLL_API(comm, tuned_module, gatherv); + TUNED_INSTALL_COLL_API(comm, tuned_module, reduce); + TUNED_INSTALL_COLL_API(comm, tuned_module, reduce_scatter); + TUNED_INSTALL_COLL_API(comm, tuned_module, reduce_scatter_block); + TUNED_INSTALL_COLL_API(comm, tuned_module, scan); + TUNED_INSTALL_COLL_API(comm, tuned_module, scatter); + TUNED_INSTALL_COLL_API(comm, tuned_module, scatterv); /* general n fan out tree */ data->cached_ntree = NULL; @@ -273,3 +307,30 @@ tuned_module_enable( mca_coll_base_module_t *module, OPAL_OUTPUT((ompi_coll_tuned_stream,"coll:tuned:module_init Tuned is in use")); return OMPI_SUCCESS; } + +static int +tuned_module_disable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm) +{ + mca_coll_tuned_module_t *tuned_module = (mca_coll_tuned_module_t *) module; + + TUNED_UNINSTALL_COLL_API(comm, tuned_module, allgather); + TUNED_UNINSTALL_COLL_API(comm, tuned_module, allgatherv); + TUNED_UNINSTALL_COLL_API(comm, tuned_module, allreduce); + TUNED_UNINSTALL_COLL_API(comm, tuned_module, alltoall); + TUNED_UNINSTALL_COLL_API(comm, tuned_module, alltoallv); + TUNED_UNINSTALL_COLL_API(comm, tuned_module, alltoallw); + TUNED_UNINSTALL_COLL_API(comm, tuned_module, barrier); + TUNED_UNINSTALL_COLL_API(comm, tuned_module, bcast); + TUNED_UNINSTALL_COLL_API(comm, tuned_module, exscan); + TUNED_UNINSTALL_COLL_API(comm, tuned_module, gather); + TUNED_UNINSTALL_COLL_API(comm, tuned_module, gatherv); + TUNED_UNINSTALL_COLL_API(comm, tuned_module, reduce); + TUNED_UNINSTALL_COLL_API(comm, tuned_module, reduce_scatter); + TUNED_UNINSTALL_COLL_API(comm, tuned_module, reduce_scatter_block); + TUNED_UNINSTALL_COLL_API(comm, tuned_module, scan); + TUNED_UNINSTALL_COLL_API(comm, tuned_module, scatter); + TUNED_UNINSTALL_COLL_API(comm, tuned_module, scatterv); + + return OMPI_SUCCESS; +} diff --git a/ompi/mca/coll/ucc/coll_ucc_module.c b/ompi/mca/coll/ucc/coll_ucc_module.c index 32861d315de..35bceac3f7e 100644 --- a/ompi/mca/coll/ucc/coll_ucc_module.c +++ b/ompi/mca/coll/ucc/coll_ucc_module.c @@ -1,8 +1,8 @@ /** - * Copyright (c) 2021 Mellanox Technologies. All rights reserved. + * Copyright (c) 2021 Mellanox Technologies. All rights reserved. * Copyright (c) 2022 Amazon.com, Inc. or its affiliates. * All Rights reserved. - * Copyright (c) 2022 NVIDIA Corporation. All rights reserved. + * Copyright (c) 2022-2024 NVIDIA Corporation. All rights reserved. 
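TUNED_INSTALL_COLL_API above skips any entry the dynamic-rules block just reset to NULL, so tuned only claims the collectives it still provides and everything else stays with the previously selected module. A reduced model of that filtering, with placeholder names:

#include <stddef.h>
#include <stdio.h>

typedef int (*coll_fn_t)(void);

static int fixed_allgather(void) { return 0; }

struct tuned_like_module {
    coll_fn_t coll_allgather;   /* may have been reset to NULL by the dynamic rules */
    coll_fn_t coll_scatterv;    /* left NULL when no implementation is provided     */
};

/* mirrors: if (module->coll_X != NULL) MCA_COLL_INSTALL_API(comm, X, ...) */
static void install_if_provided(const char *name, coll_fn_t fn)
{
    if (NULL == fn) {
        printf("%s: left to the previously selected module\n", name);
        return;
    }
    printf("%s: installed by tuned\n", name);
}

int main(void)
{
    struct tuned_like_module m = { fixed_allgather, NULL };

    install_if_provided("allgather", m.coll_allgather);
    install_if_provided("scatterv", m.coll_scatterv);
    return 0;
}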
* $COPYRIGHT$ * * Additional copyrights may follow @@ -16,8 +16,6 @@ #include "ompi/mca/coll/base/coll_tags.h" #include "ompi/mca/pml/pml.h" -#define OBJ_RELEASE_IF_NOT_NULL( obj ) if( NULL != (obj) ) OBJ_RELEASE( obj ); - static int ucc_comm_attr_keyval; /* * Initial query function that is invoked during MPI_INIT, allowing @@ -107,80 +105,9 @@ static void mca_coll_ucc_module_destruct(mca_coll_ucc_module_t *ucc_module) UCC_ERROR("ucc ompi_attr_free_keyval failed"); } } - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_allreduce_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_iallreduce_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_barrier_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_ibarrier_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_bcast_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_ibcast_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_alltoall_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_ialltoall_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_alltoallv_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_ialltoallv_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_allgather_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_iallgather_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_allgatherv_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_iallgatherv_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_reduce_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_ireduce_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_gather_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_igather_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_gatherv_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_igatherv_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_reduce_scatter_block_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_ireduce_scatter_block_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_reduce_scatter_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_ireduce_scatter_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_scatterv_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_iscatterv_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_scatter_module); - OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_iscatter_module); mca_coll_ucc_module_clear(ucc_module); } -#define SAVE_PREV_COLL_API(__api) do { \ - ucc_module->previous_ ## __api = comm->c_coll->coll_ ## __api; \ - ucc_module->previous_ ## __api ## _module = comm->c_coll->coll_ ## __api ## _module; \ - if (!comm->c_coll->coll_ ## __api || !comm->c_coll->coll_ ## __api ## _module) { \ - return OMPI_ERROR; \ - } \ - OBJ_RETAIN(ucc_module->previous_ ## __api ## _module); \ - } while(0) - -static int mca_coll_ucc_save_coll_handlers(mca_coll_ucc_module_t *ucc_module) -{ - ompi_communicator_t *comm = ucc_module->comm; - SAVE_PREV_COLL_API(allreduce); - SAVE_PREV_COLL_API(iallreduce); - SAVE_PREV_COLL_API(barrier); - SAVE_PREV_COLL_API(ibarrier); - SAVE_PREV_COLL_API(bcast); - SAVE_PREV_COLL_API(ibcast); - SAVE_PREV_COLL_API(alltoall); - SAVE_PREV_COLL_API(ialltoall); - SAVE_PREV_COLL_API(alltoallv); - SAVE_PREV_COLL_API(ialltoallv); - SAVE_PREV_COLL_API(allgather); - SAVE_PREV_COLL_API(iallgather); - SAVE_PREV_COLL_API(allgatherv); - SAVE_PREV_COLL_API(iallgatherv); - SAVE_PREV_COLL_API(reduce); - SAVE_PREV_COLL_API(ireduce); - SAVE_PREV_COLL_API(gather); - SAVE_PREV_COLL_API(igather); - SAVE_PREV_COLL_API(gatherv); - SAVE_PREV_COLL_API(igatherv); - 
SAVE_PREV_COLL_API(reduce_scatter_block); - SAVE_PREV_COLL_API(ireduce_scatter_block); - SAVE_PREV_COLL_API(reduce_scatter); - SAVE_PREV_COLL_API(ireduce_scatter); - SAVE_PREV_COLL_API(scatterv); - SAVE_PREV_COLL_API(iscatterv); - SAVE_PREV_COLL_API(scatter); - SAVE_PREV_COLL_API(iscatter); - return OMPI_SUCCESS; -} - /* ** Communicator free callback */ @@ -440,6 +367,50 @@ static inline ucc_ep_map_t get_rank_map(struct ompi_communicator_t *comm) return map; } + +#define UCC_INSTALL_COLL_API(__comm, __ucc_module, __COLL, __api) \ + do \ + { \ + if ((mca_coll_ucc_component.ucc_lib_attr.coll_types & UCC_COLL_TYPE_##__COLL)) \ + { \ + if (mca_coll_ucc_component.cts_requested & UCC_COLL_TYPE_##__COLL) \ + { \ + MCA_COLL_SAVE_API(__comm, __api, (__ucc_module)->previous_##__api, (__ucc_module)->previous_##__api##_module, "ucc"); \ + MCA_COLL_INSTALL_API(__comm, __api, mca_coll_ucc_##__api, &__ucc_module->super, "ucc"); \ + (__ucc_module)->super.coll_##__api = mca_coll_ucc_##__api; \ + } \ + if (mca_coll_ucc_component.nb_cts_requested & UCC_COLL_TYPE_##__COLL) \ + { \ + MCA_COLL_SAVE_API(__comm, i##__api, (__ucc_module)->previous_i##__api, (__ucc_module)->previous_i##__api##_module, "ucc"); \ + MCA_COLL_INSTALL_API(__comm, i##__api, mca_coll_ucc_i##__api, &__ucc_module->super, "ucc"); \ + (__ucc_module)->super.coll_i##__api = mca_coll_ucc_i##__api; \ + } \ + } \ + } while (0) + +static int mca_coll_ucc_replace_coll_handlers(mca_coll_ucc_module_t *ucc_module) +{ + ompi_communicator_t *comm = ucc_module->comm; + + UCC_INSTALL_COLL_API(comm, ucc_module, ALLREDUCE, allreduce); + UCC_INSTALL_COLL_API(comm, ucc_module, BARRIER, barrier); + UCC_INSTALL_COLL_API(comm, ucc_module, BCAST, bcast); + UCC_INSTALL_COLL_API(comm, ucc_module, ALLTOALL, alltoall); + UCC_INSTALL_COLL_API(comm, ucc_module, ALLTOALLV, alltoallv); + UCC_INSTALL_COLL_API(comm, ucc_module, ALLGATHER, allgather); + UCC_INSTALL_COLL_API(comm, ucc_module, ALLGATHERV, allgatherv); + UCC_INSTALL_COLL_API(comm, ucc_module, REDUCE, reduce); + + UCC_INSTALL_COLL_API(comm, ucc_module, GATHER, gather); + UCC_INSTALL_COLL_API(comm, ucc_module, GATHERV, gatherv); + UCC_INSTALL_COLL_API(comm, ucc_module, REDUCE_SCATTER_BLOCK, reduce_scatter_block); + UCC_INSTALL_COLL_API(comm, ucc_module, REDUCE_SCATTER, reduce_scatter); + UCC_INSTALL_COLL_API(comm, ucc_module, SCATTER, scatter); + UCC_INSTALL_COLL_API(comm, ucc_module, SCATTERV, scatterv); + + return OMPI_SUCCESS; +} + /* * Initialize module on the communicator */ @@ -470,11 +441,6 @@ static int mca_coll_ucc_module_enable(mca_coll_base_module_t *module, (void*)comm, (long long unsigned)team_params.id, ompi_comm_size(comm)); - if (OMPI_SUCCESS != mca_coll_ucc_save_coll_handlers(ucc_module)){ - UCC_ERROR("mca_coll_ucc_save_coll_handlers failed"); - goto err; - } - if (UCC_OK != ucc_team_create_post(&cm->ucc_context, 1, &team_params, &ucc_module->ucc_team)) { UCC_ERROR("ucc_team_create_post failed"); @@ -489,12 +455,18 @@ static int mca_coll_ucc_module_enable(mca_coll_base_module_t *module, goto err; } + if (OMPI_SUCCESS != mca_coll_ucc_replace_coll_handlers(ucc_module)) { + UCC_ERROR("mca_coll_ucc_replace_coll_handlers failed"); + goto err; + } + rc = ompi_attr_set_c(COMM_ATTR, comm, &comm->c_keyhash, ucc_comm_attr_keyval, (void *)module, false); if (OMPI_SUCCESS != rc) { UCC_ERROR("ucc ompi_attr_set_c failed"); goto err; } + return OMPI_SUCCESS; err: @@ -504,22 +476,53 @@ static int mca_coll_ucc_module_enable(mca_coll_base_module_t *module, return OMPI_ERROR; } +#define 
UCC_UNINSTALL_COLL_API(__comm, __ucc_module, __api) \ + do \ + { \ + if (&(__ucc_module)->super == (__comm)->c_coll->coll_##__api##_module) \ + { \ + MCA_COLL_INSTALL_API(__comm, __api, (__ucc_module)->previous_##__api, (__ucc_module)->previous_##__api##_module, "ucc"); \ + (__ucc_module)->previous_##__api = NULL; \ + (__ucc_module)->previous_##__api##_module = NULL; \ + } \ + } while (0) + +/** + * The disable will be called once per collective module, in the reverse order + * in which enable has been called. This reverse order allows the module to properly + * unregister the collective function pointers they provide for the communicator. + */ +static int +mca_coll_ucc_module_disable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm) +{ + mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t*)module; + UCC_UNINSTALL_COLL_API(comm, ucc_module, allreduce); + UCC_UNINSTALL_COLL_API(comm, ucc_module, iallreduce); + UCC_UNINSTALL_COLL_API(comm, ucc_module, barrier); + UCC_UNINSTALL_COLL_API(comm, ucc_module, ibarrier); + UCC_UNINSTALL_COLL_API(comm, ucc_module, bcast); + UCC_UNINSTALL_COLL_API(comm, ucc_module, ibcast); + UCC_UNINSTALL_COLL_API(comm, ucc_module, alltoall); + UCC_UNINSTALL_COLL_API(comm, ucc_module, ialltoall); + UCC_UNINSTALL_COLL_API(comm, ucc_module, alltoallv); + UCC_UNINSTALL_COLL_API(comm, ucc_module, ialltoallv); + UCC_UNINSTALL_COLL_API(comm, ucc_module, allgather); + UCC_UNINSTALL_COLL_API(comm, ucc_module, iallgather); + UCC_UNINSTALL_COLL_API(comm, ucc_module, allgatherv); + UCC_UNINSTALL_COLL_API(comm, ucc_module, iallgatherv); + UCC_UNINSTALL_COLL_API(comm, ucc_module, reduce); + UCC_UNINSTALL_COLL_API(comm, ucc_module, ireduce); + UCC_UNINSTALL_COLL_API(comm, ucc_module, gather); + UCC_UNINSTALL_COLL_API(comm, ucc_module, gatherv); + UCC_UNINSTALL_COLL_API(comm, ucc_module, reduce_scatter_block); + UCC_UNINSTALL_COLL_API(comm, ucc_module, reduce_scatter); + UCC_UNINSTALL_COLL_API(comm, ucc_module, scatter); + UCC_UNINSTALL_COLL_API(comm, ucc_module, scatterv); + + return OMPI_SUCCESS; +} -#define SET_COLL_PTR(_module, _COLL, _coll) do { \ - _module->super.coll_ ## _coll = NULL; \ - _module->super.coll_i ## _coll = NULL; \ - if ((mca_coll_ucc_component.ucc_lib_attr.coll_types & \ - UCC_COLL_TYPE_ ## _COLL)) { \ - if (mca_coll_ucc_component.cts_requested & \ - UCC_COLL_TYPE_ ## _COLL) { \ - _module->super.coll_ ## _coll = mca_coll_ucc_ ## _coll; \ - } \ - if (mca_coll_ucc_component.nb_cts_requested & \ - UCC_COLL_TYPE_ ## _COLL) { \ - _module->super.coll_i ## _coll = mca_coll_ucc_i ## _coll; \ - } \ - } \ - } while(0) /* * Invoked when there's a new communicator that has been created. 
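The UCC_INSTALL_COLL_API / UCC_UNINSTALL_COLL_API macros above implement a save/install/restore pattern: when the module is enabled it remembers the communicator's current provider of each collective (via MCA_COLL_SAVE_API) and installs its own function (via MCA_COLL_INSTALL_API); when it is disabled it restores the saved provider, but only if it is still the one installed. The following standalone sketch illustrates that pattern with hypothetical types and names; it is not Open MPI code and makes no claims about the internals of the MCA_COLL_* macros.

```c
#include <stdio.h>

typedef struct module module_t;
typedef int (*bcast_fn_t)(void *buf, int count, module_t *mod);

/* What a communicator stores per collective: the function pointer and the
 * module that installed it. */
typedef struct { bcast_fn_t fn; module_t *owner; } coll_slot_t;
struct module { const char *name; coll_slot_t saved_bcast; };
typedef struct { coll_slot_t bcast; } comm_t;

static int my_bcast(void *buf, int count, module_t *mod) {
    (void)buf; (void)count;
    printf("bcast served by %s\n", mod->name);
    return 0;
}

/* Enable: save whoever currently provides bcast, then install ours. */
static void module_enable(comm_t *c, module_t *m) {
    m->saved_bcast = c->bcast;
    c->bcast = (coll_slot_t){ my_bcast, m };
}

/* Disable: restore the saved provider, but only if we are still the one
 * installed (another component may have replaced us in the meantime). */
static void module_disable(comm_t *c, module_t *m) {
    if (c->bcast.owner == m) {
        c->bcast = m->saved_bcast;
    }
    m->saved_bcast = (coll_slot_t){ NULL, NULL };
}

int main(void) {
    module_t base = { "base", { NULL, NULL } };
    module_t mine = { "mine", { NULL, NULL } };
    comm_t comm = { { my_bcast, &base } };

    module_enable(&comm, &mine);
    comm.bcast.fn(NULL, 0, comm.bcast.owner);  /* served by "mine" */
    module_disable(&comm, &mine);
    comm.bcast.fn(NULL, 0, comm.bcast.owner);  /* back to "base" */
    return 0;
}
```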
@@ -554,23 +557,11 @@ mca_coll_ucc_comm_query(struct ompi_communicator_t *comm, int *priority) cm->ucc_enable = 0; return NULL; } - ucc_module->comm = comm; - ucc_module->super.coll_module_enable = mca_coll_ucc_module_enable; - *priority = cm->ucc_priority; - SET_COLL_PTR(ucc_module, BARRIER, barrier); - SET_COLL_PTR(ucc_module, BCAST, bcast); - SET_COLL_PTR(ucc_module, ALLREDUCE, allreduce); - SET_COLL_PTR(ucc_module, ALLTOALL, alltoall); - SET_COLL_PTR(ucc_module, ALLTOALLV, alltoallv); - SET_COLL_PTR(ucc_module, REDUCE, reduce); - SET_COLL_PTR(ucc_module, ALLGATHER, allgather); - SET_COLL_PTR(ucc_module, ALLGATHERV, allgatherv); - SET_COLL_PTR(ucc_module, GATHER, gather); - SET_COLL_PTR(ucc_module, GATHERV, gatherv); - SET_COLL_PTR(ucc_module, REDUCE_SCATTER, reduce_scatter_block); - SET_COLL_PTR(ucc_module, REDUCE_SCATTERV, reduce_scatter); - SET_COLL_PTR(ucc_module, SCATTERV, scatterv); - SET_COLL_PTR(ucc_module, SCATTER, scatter); + ucc_module->comm = comm; + ucc_module->super.coll_module_enable = mca_coll_ucc_module_enable; + ucc_module->super.coll_module_disable = mca_coll_ucc_module_disable; + *priority = cm->ucc_priority; + return &ucc_module->super; } diff --git a/ompi/mca/coll/xhc/Makefile.am b/ompi/mca/coll/xhc/Makefile.am new file mode 100644 index 00000000000..35db0b89c12 --- /dev/null +++ b/ompi/mca/coll/xhc/Makefile.am @@ -0,0 +1,44 @@ +# +# Copyright (c) 2021-2023 Computer Architecture and VLSI Systems (CARV) +# Laboratory, ICS Forth. All rights reserved. +# $COPYRIGHT$ +# +# Additional copyrights may follow +# +# $HEADER$ +# + +dist_opaldata_DATA = help-coll-xhc.txt + +sources = \ + coll_xhc.h \ + coll_xhc_atomic.h \ + coll_xhc.c \ + coll_xhc_component.c \ + coll_xhc_module.c \ + coll_xhc_bcast.c \ + coll_xhc_barrier.c \ + coll_xhc_reduce.c \ + coll_xhc_allreduce.c + +# Make the output library in this directory, and name it either +# mca__.la (for DSO builds) or libmca__.la +# (for static builds). + +component_noinst = +component_install = +if MCA_BUILD_ompi_coll_xhc_DSO +component_install += mca_coll_xhc.la +else +component_noinst += libmca_coll_xhc.la +endif + +mcacomponentdir = $(ompilibdir) +mcacomponent_LTLIBRARIES = $(component_install) +mca_coll_xhc_la_SOURCES = $(sources) +mca_coll_xhc_la_LDFLAGS = -module -avoid-version +mca_coll_xhc_la_LIBADD = $(top_builddir)/ompi/lib@OMPI_LIBMPI_NAME@.la + +noinst_LTLIBRARIES = $(component_noinst) +libmca_coll_xhc_la_SOURCES = $(sources) +libmca_coll_xhc_la_LDFLAGS = -module -avoid-version diff --git a/ompi/mca/coll/xhc/README.md b/ompi/mca/coll/xhc/README.md new file mode 100644 index 00000000000..325170b7179 --- /dev/null +++ b/ompi/mca/coll/xhc/README.md @@ -0,0 +1,282 @@ +# XHC: XPMEM-based Hierarchical Collectives + +The XHC component, implements hierarchical & topology-aware intra-node MPI +collectives, utilizing XPMEM in order to achieve efficient shared address space +memory access between processes. + +## Main features + +* Constructs an **n-level hierarchy** (i.e. no algorithmic limitation on level +count), following the system's hardware topology. Ranks/processes are grouped +together according to their relative locations; this information is known +thanks to Hwloc, and is obtained via OpenMPI's integrated book-keeping. 
+ + Topological features that can currently be defined (configurable via MCA params): + + - NUMA node + - CPU Socket + - L1, L2, L3 cache + - Hwthread, core + - Node/flat (no hierarchy) + + Example of a 3-level XHC hierarchy (sensitivity to numa & socket locality): + + ![Example of 3-level XHC hierarchy](resources/xhc-hierarchy.svg) + + Furthermore, support for custom virtual user-defined hierarchies is + available, to aid when fine-grained control over the communication pattern + is necessary. + +* Support for both **zero-copy** and **copy-in-copy-out** data transportation. + - Switchover at a static but configurable message size. + + - CICO buffers are permanently attached at module initialization. + + - Application buffers are attached on the fly the first time they appear, saved + in and recovered from the registration cache in subsequent appearances + (assuming smsc/xpmem). + +* Integration with Open MPI's `opal/smsc` (shared-memory-single-copy) +framework. Selection of `smsc/xpmem` is highly recommended. + + - Bcast support: XPMEM, CMA, KNEM + - Allreduce support: XPMEM + - Barrier support: *(all, irrelevant)* + +* Data-wise **pipelining** across all levels of the hierarchy lowers +hierarchy-induced start-up overheads. Pipelining also allows for +interleaving of operations in some collectives (reduce+bcast in allreduce). + +* **Lock-free** single-writer synchronization, with cache-line separation where +necessary/beneficial. Consistency is ensured via lightweight memory barriers. + +## Configuration options -- MCA params + +XHC can be customized via a number of standard Open MPI MCA parameters, though +defaults that should satisfy a wide range of systems are in place. + +The available parameters: + +#### *(prepend with "coll_xhc_")* +*(list may be outdated, please also check `ompi_info` and `coll_xhc_component.c`)* + +* **priority** (default `0`): The priority of the coll/xhc component, used +during the component selection process. + +* **print_info** (default `false`): Print information about XHC's generated +hierarchy and its configuration. + +* **shmem_backing** (default `/dev/shm`): Backing directory for shmem files +used for XHC's synchronization fields and CICO buffers. + +* **dynamic_leader** (default `false`): Enables the feature that dynamically +elects an XHC group leader at each collective (currently only applicable +to bcast). + +* **dynamic_reduce** (default `1`=`non-float`): Controls the +feature that allows for out-of-order reduction. XHC ranks reduce chunks +directly from multiple peers' buffers; dynamic reduction allows them to +temporarily skip a peer when the expected data is not yet prepared, instead of +stalling. Setting it to `2`=`all` may harm the reproducibility of float-based +reductions. + +* **lb_reduce_leader_assist** (default `top,first`): Controls the +leader-to-member load balancing mode in reductions. If set to none/empty (`""`), +only non-leader group members perform reductions. With `top` in the list, the +leader of the top-most level also performs reductions in its group. With +`first` in the list, leaders help with the reduction workload for just one +chunk at the beginning of the operation. If `all` is specified, all group +members, including the leaders, perform reductions indiscriminately. + +* **force_reduce** (default `false`): Force-enable the "special" Reduce +implementation for all calls to MPI_Reduce.
This implementation assumes that +the `rbuf` parameter to MPI_Reduce is valid and appropriately sized for all +ranks, not just the root -- you have to make sure that this is indeed the case +with the application at hand. Only works with `root = 0`. + +* **hierarchy** (default `"numa,socket"`): A comma-separated list of +topological features to which XHC's hierarchy-building algorithm should be +sensitive. `ompi_info` reports the possible values for the parameter. + + - In some ways, this is "just" a suggestion. The resulting hierarchy may + not exactly match the requested one. Reasons this can occur: + + - A requested topological feature does not effectively segment the set + of ranks. (eg. `numa` was specified, but all ranks reside in the same + NUMA node) + + - No feature that all ranks have in common was provided. This is a more + intrinsic detail that you probably don't need to be aware of, but you + might come across it if eg. you investigate the output of `print_info`. An + additional level will automatically be added in this case; no need to + worry about it. + + For all intents and purposes, a hierarchy of `numa,socket` is + interpreted as "segment the ranks according to NUMA node locality, + and then further segment them according to CPU socket locality". + Three groups will be created: the intra-NUMA one, the intra-socket + one, and an intra-node one. + + - The provided features will automatically be re-ordered when their + order does not match their order in the physical system. (unless a + virtual hierarchy feature is present in the list) + + - *Virtual Hierarchies*: The string may alternatively also contain "rank + lists" which specify exactly which ranks to group together, as well as some + other special modifiers. See + `coll_xhc_component.c:xhc_component_parse_hierarchy()` for further + explanation and syntax information. + +* **chunk_size** (default `16K`): The chunk size for the pipelining process. +Data is processed (eg. broadcast, reduced) in pieces of this size at a time. + + - It's possible to have a different chunk size for each level of the + hierarchy, achieved by providing a comma-separated list of sizes (eg. + `"16K,16K,128K"`) instead of a single one. The sizes in this list *DO NOT* + correspond to the items on the hierarchy list; the hierarchy keys might be + re-ordered or reduced to match the system, but the chunk sizes will be + consumed in the order they are given, left-to-right -> bottom-to-top. + +* **uniform_chunks** (default `true`): Automatically optimize the chunk size +in reduction collectives, according to the message size, so that all members +will perform equal work. + +* **uniform_chunks_min** (default `1K`): The lowest allowed value for the chunk +size when uniform chunks are enabled. Each worker will reduce at least this much +data, or we don't bother splitting the workload up. + +* **cico_max** (default `1K`): Copy-in-copy-out, instead of single-copy, will be +used for messages of *cico_max* bytes or less. + +*(Removed Parameters)* + +* **rcache_max**, **rcache_max_global** *(REMOVED with shift to opal/smsc)*: +Limit on the number of attachments that the registration cache should hold. + + - A case can be made about their usefulness. If desired, they should be + re-implemented at the smsc level. + +## Limitations + +- *Intra-node support only* + - Usage in multi-node scenarios is possible via Open MPI's HAN. + +- **Heterogeneity**: XHC does not support nodes with non-uniform (rank-wise) +datatype representations.
(determined according to `proc_arch` field) + +- **Non-commutative** operators are not supported by XHC's reduction +collectives. In past versions, they were supported, but only with the flat +hierarchy configuration; this could make a return at some point. + +- XHC's Reduce is not fully complete. Instead, it is a "special" implementation +of MPI_Reduce, that is realized as a sub-case of XHC's Allreduce. + + - If the caller guarantees that the `rbuf` parameter is valid for all ranks + (not just the root), like in Allreduce, this special Reduce can be invoked + by specifying `root=-1`, which will trigger a Reduce to rank `0` (the only + one currently supported). + + - Current prime use-case: HAN's Allreduce + + - Furthermore, if it is guaranteed that all Reduce calls in an application + satisfy the above criteria, see about the `force_reduce` MCA parameter. + + - XHC's Reduce is not yet fully optimized for small messages. + +## Building + +XHC is built as a standard mca/coll component. + +To reap its full benefits, XPMEM support in OpenMPI is required. XHC will build +and work without it, but the reduction operations will be disabled and +broadcast will fall-back to less efficient mechanisms (CMA, KNEM). + +## Running + +In order for the XHC component to be chosen, make sure that its priority is +higher than other components that provide the collectives of interest; use the +`coll_xhc_priority` MCA parameter. If a list of collective modules is included +via the `coll` MCA parameter, make sure XHC is in the list. + +* You may also want to add the `--bind-to core` param. Otherwise, the reported +process localities might be too general, preventing XHC from correctly +segmenting the system. (`coll_xhc_print_info` will report the generated +hierarchy) + +### Tuning + +* Optional: You might wish to manually specify the topological features that +XHC's hierarchy should conform to. The default is `numa,socket`, which will +group the processes according to NUMA locality and then further group them +according to socket locality. See the `coll_xhc_hierarchy` param. + + - Example: `--mca coll_xhc_hierarchy numa,socket` + - Example: `--mca coll_xhc_hierarchy numa` + - Example: `--mca coll_xhc_hierarchy flat` + + In some systems, small-message Broadcast or the Barrier operation might + perform better with a flat tree instead of a hierarchical one. Currently, + manual benchmarking is required to accurately determine this. + +* Optional: You might wish to tune XHC's chunk size (default `16K`). Use the +`coll_xhc_chunk_size` param, and try values close to the default and see if +improvements are observed. You may even try specifying different chunk sizes +for each hierarchy level -- use the same process, starting from the same chunk +size for all levels and decreasing/increasing from there. + + - Example: `--mca coll_xhc_chunk_size 16K` + - Example: `--mca coll_xhc_chunk_size 16K,32K,128K` + +* Optional: If you wish to focus on latencies of small messages, you can try +altering the cico-to-zcopy switchover point (`coll_xhc_cico_max`, default +`1K`). + + - Example: `--mca coll_xhc_cico_max 1K` + +* Optional: If your application is heavy in Broadcast calls and you suspect +that specific ranks might be joining the collective with delay and causing +others to stall waiting for them, you could try enabling dynamic leadership +(`coll_xhc_dynamic_leader`), and seeing if it marks an improvement. 
+ + - Example: `--mca coll_xhc_dynamic_leader 1` + +### Example command lines + +*Assuming `PATH` and `LD_LIBRARY_PATH` have been set appropriately.* + +Default XHC configuration: +`$ mpirun --mca coll libnbc,basic,xhc --mca coll_xhc_priority 100 --bind-to core ` + +XHC w/ numa-sensitive hierarchy, chunk size @ 16K: +`$ mpirun --mca coll libnbc,basic,xhc --mca coll_xhc_priority 100 --mca coll_xhc_hierarchy numa --mca coll_xhc_chunk_size 16K --bind-to core ` + +XHC with flat hierarchy (ie. none at all): +`$ mpirun --mca coll libnbc,basic,xhc --mca coll_xhc_priority 100 --mca coll_xhc_hierarchy node [--bind-to core] ` + +## Publications + +1. **A framework for hierarchical single-copy MPI collectives on multicore nodes**, +*George Katevenis, Manolis Ploumidis, Manolis Marazakis*, +IEEE Cluster 2022, Heidelberg, Germany. +https://ieeexplore.ieee.org/document/9912729 + +## Contact + +- George Katevenis (gkatev@ics.forth.gr) +- Manolis Ploumidis (ploumid@ics.forth.gr) + +Computer Architecture and VLSI Systems (CARV) Laboratory, ICS Forth + +## Acknowledgments + +We thankfully acknowledge the support of the European Commission and the Greek +General Secretariat for Research and Innovation under the EuroHPC Programme +through the **DEEP-SEA** project (GA 955606). National contributions from the +involved state members (including the Greek General Secretariat for Research +and Innovation) match the EuroHPC funding. + +This work is partly supported by project **EUPEX**, which has received funding +from the European High-Performance Computing Joint Undertaking (JU) under grant +agreement No 101033975. The JU receives support from the European Union's +Horizon 2020 re-search and innovation programme and France, Germany, Italy, +Greece, United Kingdom, Czech Republic, Croatia. diff --git a/ompi/mca/coll/xhc/coll_xhc.c b/ompi/mca/coll/xhc/coll_xhc.c new file mode 100644 index 00000000000..d7221ffb37a --- /dev/null +++ b/ompi/mca/coll/xhc/coll_xhc.c @@ -0,0 +1,748 @@ +/* + * Copyright (c) 2021-2023 Computer Architecture and VLSI Systems (CARV) + * Laboratory, ICS Forth. All rights reserved. 
+ * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +#include "ompi_config.h" + +#include "mpi.h" +#include "ompi/constants.h" +#include "ompi/communicator/communicator.h" +#include "ompi/mca/coll/coll.h" +#include "ompi/mca/coll/base/base.h" + +#include "opal/mca/rcache/rcache.h" +#include "opal/mca/shmem/base/base.h" +#include "opal/mca/smsc/smsc.h" + +#include "opal/include/opal/align.h" +#include "opal/util/show_help.h" +#include "opal/util/minmax.h" + +#include "coll_xhc.h" + +static int xhc_comms_make(ompi_communicator_t *ompi_comm, + xhc_peer_info_t *peer_info, xhc_comm_t **comms_dst, + int *comm_count_dst, xhc_loc_t *hierarchy, int hierarchy_len); +static void xhc_comms_destroy(xhc_comm_t *comms, int comm_count); + +static int xhc_print_info(xhc_module_t *module, + ompi_communicator_t *comm, xhc_data_t *data); + +static void *xhc_shmem_create(opal_shmem_ds_t *seg_ds, size_t size, + ompi_communicator_t *ompi_comm, const char *name_chr_s, int name_chr_i); +static void *xhc_shmem_attach(opal_shmem_ds_t *seg_ds); +static mca_smsc_endpoint_t *xhc_smsc_ep(xhc_peer_info_t *peer_info); + +// ------------------------------------------------ + +int mca_coll_xhc_lazy_init(xhc_module_t *module, ompi_communicator_t *comm) { + + int comm_size = ompi_comm_size(comm); + int rank = ompi_comm_rank(comm); + + xhc_peer_info_t *peer_info = module->peer_info; + + opal_shmem_ds_t *peer_cico_ds = NULL; + xhc_data_t *data = NULL; + + xhc_coll_fns_t xhc_fns; + + int return_code = OMPI_SUCCESS; + int ret; + + errno = 0; + + // ---- + + /* XHC requires rank communication during its initialization. + * Temporarily apply the saved fallback collective modules, + * and restore XHC's after initialization is done. */ + xhc_module_install_fallback_fns(module, comm, &xhc_fns); + + // ---- + + ret = xhc_module_prepare_hierarchy(module, comm); + if(ret != OMPI_SUCCESS) { + RETURN_WITH_ERROR(return_code, ret, end); + } + + // ---- + + data = malloc(sizeof(xhc_data_t)); + peer_cico_ds = malloc(comm_size * sizeof(opal_shmem_ds_t)); + if(!data || !peer_cico_ds) { + RETURN_WITH_ERROR(return_code, OMPI_ERR_OUT_OF_RESOURCE, end); + } + + *data = (xhc_data_t) { + .comms = NULL, + .comm_count = -1, + + .pvt_coll_seq = 0 + }; + + // ---- + + if(OMPI_XHC_CICO_MAX > 0) { + opal_shmem_ds_t cico_ds; + + void *my_cico = xhc_shmem_create(&cico_ds, + OMPI_XHC_CICO_MAX, comm, "cico", 0); + if(!my_cico) { + RETURN_WITH_ERROR(return_code, OMPI_ERR_OUT_OF_RESOURCE, end); + } + + /* Manually "touch" to assert allocation in local NUMA node + * (assuming linux's default firt-touch-alloc NUMA policy) */ + memset(my_cico, 0, OMPI_XHC_CICO_MAX); + + ret = comm->c_coll->coll_allgather(&cico_ds, + sizeof(opal_shmem_ds_t), MPI_BYTE, peer_cico_ds, + sizeof(opal_shmem_ds_t), MPI_BYTE, comm, + comm->c_coll->coll_allgather_module); + if(ret != OMPI_SUCCESS) { + RETURN_WITH_ERROR(return_code, ret, end); + } + + for(int r = 0; r < comm_size; r++) { + peer_info[r].cico_ds = peer_cico_ds[r]; + } + + peer_info[rank].cico_buffer = my_cico; + } + + // ---- + + /* An XHC communicator is created for each level of the hierarchy. + * The hierachy must be in an order of most-specific to most-general. 
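+ * (For instance, with the default "numa,socket" hierarchy described in the
+ * README, groups are formed per NUMA node first, and the NUMA-level leaders
+ * are then grouped per socket.)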
*/ + + ret = xhc_comms_make(comm, peer_info, &data->comms, &data->comm_count, + module->hierarchy, module->hierarchy_len); + if(ret != OMPI_SUCCESS) { + RETURN_WITH_ERROR(return_code, ret, end); + } + + for(int i = 0, c = 0; i < data->comm_count; i++) { + data->comms[i].chunk_size = module->chunks[c]; + c = opal_min(c + 1, module->chunks_len - 1); + } + + if(module->chunks_len < data->comm_count) { + opal_output_verbose(MCA_BASE_VERBOSE_WARN, + ompi_coll_base_framework.framework_output, + "coll:xhc: Warning: The chunk sizes count is shorter than the " + "hierarchy size; filling in with the last entry provided"); + } else if(module->chunks_len > data->comm_count) { + opal_output_verbose(MCA_BASE_VERBOSE_WARN, + ompi_coll_base_framework.framework_output, + "coll:xhc: Warning: The chunk size count is larger than the " + "hierarchy size; omitting last entries"); + } + + // ---- + + if(mca_coll_xhc_component.print_info) { + ret = xhc_print_info(module, comm, data); + if(ret != OMPI_SUCCESS) { + RETURN_WITH_ERROR(return_code, ret, end); + } + } + + // ---- + + module->data = data; + module->init = true; + + end: + + xhc_module_install_fns(module, comm, xhc_fns); + + free(peer_cico_ds); + + if(return_code != 0) { + opal_show_help("help-coll-xhc.txt", "xhc-init-failed", true, + return_code, errno, strerror(errno)); + + xhc_fini(module); + } + + return return_code; +} + +void mca_coll_xhc_fini(mca_coll_xhc_module_t *module) { + if(module->data) { + xhc_data_t *data = module->data; + + if(data->comm_count >= 0) { + xhc_comms_destroy(data->comms, data->comm_count); + } + + free(data->comms); + free(data); + } + + if(module->peer_info) { + for(int r = 0; r < module->comm_size; r++) { + if(module->peer_info[r].cico_buffer) { + if(r == module->rank) { + // OMPI issue #11123 + // opal_shmem_unlink(&module->peer_info[r].cico_ds); + } + + opal_shmem_segment_detach(&module->peer_info[r].cico_ds); + } + + if(module->peer_info[r].smsc_ep) { + MCA_SMSC_CALL(return_endpoint, module->peer_info[r].smsc_ep); + } + } + } +} + +// ------------------------------------------------ + +/* This method is where the hierarchy of XHC is constructed; it receives + * the hierarchy specifications (hierarchy param) and groups ranks together + * among them. The process begins with the first locality in the list. All + * ranks that share this locality (determined via the relative peer to peer + * distances) become siblings. The one amongst them with the lowest rank + * number becomes the manager/leader of the group. The members don't really + * need to keep track of the actual ranks of their siblings -- only the rank + * of the group's leader/manager, the size of the group, and their own member + * ID. The process continues with the next locality, only that now only the + * ranks that became leaders in the previous level are eligible (determined + * via comm_candidate, see inline comments). 
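+ *
+ * A hypothetical example (not taken from the source): 8 ranks on one node,
+ * hierarchy = {NUMA, socket}, ranks 0-3 on NUMA 0 and ranks 4-7 on NUMA 1,
+ * both NUMA nodes on the same socket. The NUMA level forms the groups
+ * {0,1,2,3} with leader 0 and {4,5,6,7} with leader 4; on the socket level
+ * only those leaders are candidates, so the group is {0,4} with leader 0.
+ * Rank 2, for instance, only records that its NUMA-level leader is rank 0,
+ * that its group has 4 members, and that its own member ID is 2.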
*/ +static int xhc_comms_make(ompi_communicator_t *ompi_comm, + xhc_peer_info_t *peer_info, xhc_comm_t **comms_dst, + int *comm_count_dst, xhc_loc_t *hierarchy, int hierarchy_len) { + + int ompi_rank = ompi_comm_rank(ompi_comm); + int ompi_size = ompi_comm_size(ompi_comm); + + xhc_comm_t *comms = NULL; + int comms_size = 0; + int comm_count = 0; + + opal_shmem_ds_t *comm_ctrl_ds; + bool *comm_candidate; + + size_t smsc_reg_size = 0; + + int return_code = OMPI_SUCCESS; + int ret; + + comms = malloc((comms_size = 5) * sizeof(xhc_comm_t)); + comm_ctrl_ds = malloc(ompi_size * sizeof(opal_shmem_ds_t)); + comm_candidate = malloc(ompi_size * sizeof(bool)); + + if(!comms || !comm_ctrl_ds || !comm_candidate) { + RETURN_WITH_ERROR(return_code, OMPI_ERR_OUT_OF_RESOURCE, end); + } + + if(mca_smsc_base_has_feature(MCA_SMSC_FEATURE_REQUIRE_REGISTRATION)) { + smsc_reg_size = mca_smsc_base_registration_data_size(); + } + + for(int h = 0; h < hierarchy_len; h++) { + xhc_comm_t *xc = &comms[comm_count]; + + if(comm_count == comms_size) { + void *tmp = realloc(comms, (comms_size *= 2) * sizeof(xhc_comm_t)); + if(!tmp) { + RETURN_WITH_ERROR(return_code, OMPI_ERR_OUT_OF_RESOURCE, end); + } + comms = tmp; + } + + *xc = (xhc_comm_t) { + .locality = hierarchy[h], + + .size = 0, + .manager_rank = -1, + + .member_info = NULL, + .reduce_queue = NULL, + + .comm_ctrl = NULL, + .member_ctrl = NULL, + + .ctrl_ds = (opal_shmem_ds_t) {0} + }; + + // ---- + + /* Only ranks that were leaders in the previous level are candidates + * for this one. Every rank advertises whether others may consider + * it for inclusion via an Allgather. */ + + bool is_candidate = (comm_count == 0 + || comms[comm_count - 1].manager_rank == ompi_rank); + + ret = ompi_comm->c_coll->coll_allgather(&is_candidate, 1, + MPI_C_BOOL, comm_candidate, 1, MPI_C_BOOL, + ompi_comm, ompi_comm->c_coll->coll_allgather_module); + if(ret != OMPI_SUCCESS) { + RETURN_WITH_ERROR(return_code, ret, comm_error); + } + + for(int r = 0; r < ompi_size; r++) { + + /* If on a non-bottom comm, only managers of the previous + * comm are "full" members. However, this procedure also has + * to take place for the bottom-most comm; even if this is the + * current rank's bottom-most comm, it may not actually be so, + * for another rank (eg. with some non-symmetric hierarchies). */ + if(comm_candidate[r] == false) { + continue; + } + + // Non-local --> not part of the comm :/ + if(!PEER_IS_LOCAL(peer_info, r, xc->locality)) { + continue; + } + + /* The member ID means slightly different things whether on the + * bottom-most comm or not. On the bottom-most comm, a rank can + * either be a "full" member or not. However, on higher-up comms, + * if a rank was not a manager on the previous comm, it will not + * a "full" member. Instead, it will be a "potential" member, in + * that it keeps information about this comm, and is ready to + * take over duties and act as a normal member for a specific + * collective (eg. dynamic leader feature, or root != manager). */ + if(r == ompi_rank || (comm_count > 0 && r == comms[comm_count - 1].manager_rank)) { + xc->member_id = xc->size; + } + + // First rank to join the comm becomes the manager + if(xc->manager_rank == -1) { + xc->manager_rank = r; + } + + xc->size++; + } + + /* If there are no local peers in regards to this locality, no + * XHC comm is created for this process on this level. 
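+ * (For instance, with a "numa,socket" hierarchy on a machine that has
+ * exactly one NUMA node per socket, the socket level would group each NUMA
+ * leader only with itself, and is therefore skipped.)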
*/ + if(xc->size <= 1) { + opal_output_verbose(MCA_BASE_VERBOSE_WARN, + ompi_coll_base_framework.framework_output, + "coll:xhc: Warning: Locality 0x%04x does not result " + "in any new groupings; skipping it", xc->locality); + + /* All ranks must participate in the "control struct sharing" + * allgather, even if useless to this rank to some of them */ + + ret = ompi_comm->c_coll->coll_allgather(&xc->ctrl_ds, + sizeof(opal_shmem_ds_t), MPI_BYTE, comm_ctrl_ds, + sizeof(opal_shmem_ds_t), MPI_BYTE, ompi_comm, + ompi_comm->c_coll->coll_allgather_module); + if(ret != OMPI_SUCCESS) { + RETURN_WITH_ERROR(return_code, ret, comm_error); + } + + xhc_comms_destroy(xc, 1); + continue; + } + + // ---- + + /* Init comm stuff */ + + xc->member_info = calloc(xc->size, sizeof(xhc_member_info_t)); + if(xc->member_info == NULL) { + RETURN_WITH_ERROR(return_code, OMPI_ERR_OUT_OF_RESOURCE, comm_error); + } + + xc->reduce_queue = OBJ_NEW(opal_list_t); + if(!xc->reduce_queue) { + RETURN_WITH_ERROR(return_code, OMPI_ERR_OUT_OF_RESOURCE, comm_error); + } + + for(int m = 0; m < xc->size - 1; m++) { + xhc_rq_item_t *item = OBJ_NEW(xhc_rq_item_t); + if(!item) { + RETURN_WITH_ERROR(return_code, + OMPI_ERR_OUT_OF_RESOURCE, comm_error); + } + + opal_list_append(xc->reduce_queue, (opal_list_item_t *) item); + } + + // ---- + + // Create shared structs + if(ompi_rank == xc->manager_rank) { + size_t ctrl_len = sizeof(xhc_comm_ctrl_t) + smsc_reg_size + + xc->size * sizeof(xhc_member_ctrl_t); + + char *ctrl_base = xhc_shmem_create(&xc->ctrl_ds, ctrl_len, + ompi_comm, "ctrl", comm_count); + if(ctrl_base == NULL) { + RETURN_WITH_ERROR(return_code, OMPI_ERROR, comm_error); + } + + /* Manually "touch" to assert allocation in local NUMA node + * (assuming linux's default firt-touch-alloc NUMA policy) */ + memset(ctrl_base, 0, ctrl_len); + + xc->comm_ctrl = (void *) ctrl_base; + xc->member_ctrl = (void *) (ctrl_base + + sizeof(xhc_comm_ctrl_t) + smsc_reg_size); + } + + /* The comm's managers share the details of the communication structs + * with their children, so that they may attach to them. Because + * there's not any MPI communicator formed that includes (only) the + * members of the XHC comm, the sharing is achieved with a single + * Allgather, instead of a Broadcast inside each XHC comm. 
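+ * (Every rank thus receives every manager's descriptor and keeps only the
+ * one of its own manager, comm_ctrl_ds[xc->manager_rank], just below.)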
*/ + + ret = ompi_comm->c_coll->coll_allgather(&xc->ctrl_ds, + sizeof(opal_shmem_ds_t), MPI_BYTE, comm_ctrl_ds, + sizeof(opal_shmem_ds_t), MPI_BYTE, ompi_comm, + ompi_comm->c_coll->coll_allgather_module); + if(ret != OMPI_SUCCESS) { + RETURN_WITH_ERROR(return_code, ret, comm_error); + } + + // Attach to manager's shared structs + if(ompi_rank != xc->manager_rank) { + xc->ctrl_ds = comm_ctrl_ds[xc->manager_rank]; + + char *ctrl_base = xhc_shmem_attach(&xc->ctrl_ds); + if(ctrl_base == NULL) { + RETURN_WITH_ERROR(return_code, OMPI_ERROR, comm_error); + } + + xc->comm_ctrl = (void *) ctrl_base; + xc->member_ctrl = (void *) (ctrl_base + + sizeof(xhc_comm_ctrl_t) + smsc_reg_size); + } + + xc->my_member_ctrl = &xc->member_ctrl[xc->member_id]; + xc->my_member_info = &xc->member_info[xc->member_id]; + + // ---- + + comm_count++; + + continue; + + comm_error: { + xhc_comms_destroy(comms, comm_count+1); + comm_count = -1; + + goto end; + } + } + + REALLOC(comms, comm_count, xhc_comm_t); + + *comms_dst = comms; + *comm_count_dst = comm_count; + + end: + + free(comm_ctrl_ds); + free(comm_candidate); + + if(return_code != OMPI_SUCCESS) { + free(comms); + } + + return return_code; +} + +static void xhc_comms_destroy(xhc_comm_t *comms, int comm_count) { + bool is_manager = true; + + for(int i = 0; i < comm_count; i++) { + xhc_comm_t *xc = &comms[i]; + + if(xc->member_id != 0) { + is_manager = false; + } + + free(xc->member_info); + + if(xc->reduce_queue) { + OPAL_LIST_RELEASE(xc->reduce_queue); + } + + if(xc->comm_ctrl) { + if(is_manager) { + // OMPI issue #11123 + // opal_shmem_unlink(&xc->ctrl_ds); + (void) is_manager; + } + + opal_shmem_segment_detach(&xc->ctrl_ds); + } + + *xc = (xhc_comm_t) {0}; + } +} + +static int xhc_print_info(xhc_module_t *module, + ompi_communicator_t *comm, xhc_data_t *data) { + + int rank = ompi_comm_rank(comm); + int ret; + + if(rank == 0) { + char *drval_str; + char *lb_rla_str; + char *un_min_str; + + switch(mca_coll_xhc_component.dynamic_reduce) { + case OMPI_XHC_DYNAMIC_REDUCE_DISABLED: + drval_str = "OFF"; break; + case OMPI_XHC_DYNAMIC_REDUCE_NON_FLOAT: + drval_str = "ON (non-float)"; break; + case OMPI_XHC_DYNAMIC_REDUCE_ALL: + drval_str = "ON (all)"; break; + default: + drval_str = "???"; + } + + switch(mca_coll_xhc_component.lb_reduce_leader_assist) { + case OMPI_XHC_LB_RLA_TOP_LEVEL: + lb_rla_str = "top level"; break; + case OMPI_XHC_LB_RLA_FIRST_CHUNK: + lb_rla_str = "first chunk"; break; + case OMPI_XHC_LB_RLA_TOP_LEVEL | OMPI_XHC_LB_RLA_FIRST_CHUNK: + lb_rla_str = "top level + first chunk"; break; + case OMPI_XHC_LB_RLA_ALL: + lb_rla_str = "all"; break; + default: + lb_rla_str = "???"; + } + + ret = opal_asprintf(&un_min_str, " (min '%zu' bytes)", + mca_coll_xhc_component.uniform_chunks_min); + if(ret < 0) { + return OMPI_ERR_OUT_OF_RESOURCE; + } + + printf("------------------------------------------------\n" + "OMPI coll/xhc @ %s, priority %d\n" + " dynamic leader '%s', dynamic reduce '%s'\n" + " reduce load-balancing leader-assist '%s'\n" + " allreduce uniform chunks '%s'%s\n" + " CICO up until %zu bytes, barrier root %d\n\n" + "------------------------------------------------\n", + comm->c_name, mca_coll_xhc_component.priority, + (mca_coll_xhc_component.dynamic_leader ? "ON" : "OFF"), + drval_str, lb_rla_str, + (mca_coll_xhc_component.uniform_chunks ? "ON" : "OFF"), + (mca_coll_xhc_component.uniform_chunks ? 
un_min_str : ""), + mca_coll_xhc_component.cico_max, + mca_coll_xhc_component.barrier_root); + + free(un_min_str); + } + + for(int i = 0; i < data->comm_count; i++) { + char *mlist = NULL; + char *tmp; + + ret = opal_asprintf(&mlist, "%d", data->comms[i].manager_rank); + if(ret < 0) { + return OMPI_ERR_OUT_OF_RESOURCE; + } + + for(int m = 1; m < data->comms[i].size; m++) { + if(m == data->comms[i].member_id) { + if(i == 0 || data->comms[i-1].manager_rank == rank) { + ret = opal_asprintf(&tmp, "%s %d", mlist, rank); + } else { + ret = opal_asprintf(&tmp, "%s _", mlist); + } + } else { + ret = opal_asprintf(&tmp, "%s x", mlist); + } + + free(mlist); + mlist = tmp; + + if(ret < 0) { + return OMPI_ERR_OUT_OF_RESOURCE; + } + } + + printf("XHC comm loc=0x%08x chunk_size=%zu with %d members [%s]\n", + data->comms[i].locality, data->comms[i].chunk_size, + data->comms[i].size, mlist); + + free(mlist); + } + + return OMPI_SUCCESS; +} + +// ------------------------------------------------ + +static void *xhc_shmem_create(opal_shmem_ds_t *seg_ds, size_t size, + ompi_communicator_t *ompi_comm, const char *name_chr_s, int name_chr_i) { + + char *shmem_file; + int ret; + + // xhc_shmem_seg.@..:_: + + ret = opal_asprintf(&shmem_file, "%s" OPAL_PATH_SEP "xhc_shmem_seg.%u@%s.%x.%d:%d_%s:%d", + mca_coll_xhc_component.shmem_backing, geteuid(), opal_process_info.nodename, + OPAL_PROC_MY_NAME.jobid, ompi_comm_rank(MPI_COMM_WORLD), ompi_comm_get_local_cid(ompi_comm), + name_chr_s, name_chr_i); + + if(ret < 0) { + return NULL; + } + + // Not 100% sure what this does!, copied from btl/sm + opal_pmix_register_cleanup(shmem_file, false, false, false); + + ret = opal_shmem_segment_create(seg_ds, shmem_file, size); + + free(shmem_file); + + if(ret != OPAL_SUCCESS) { + opal_output_verbose(MCA_BASE_VERBOSE_ERROR, + ompi_coll_base_framework.framework_output, + "coll:xhc: Error: Could not create shared memory segment"); + + return NULL; + } + + void *addr = xhc_shmem_attach(seg_ds); + + if(addr == NULL) { + opal_shmem_unlink(seg_ds); + } + + return addr; +} + +static void *xhc_shmem_attach(opal_shmem_ds_t *seg_ds) { + void *addr = opal_shmem_segment_attach(seg_ds); + + if(addr == NULL) { + opal_output_verbose(MCA_BASE_VERBOSE_ERROR, + ompi_coll_base_framework.framework_output, + "coll:xhc: Error: Could not attach to shared memory segment"); + } + + return addr; +} + +static mca_smsc_endpoint_t *xhc_smsc_ep(xhc_peer_info_t *peer_info) { + if(!peer_info->smsc_ep) { + peer_info->smsc_ep = MCA_SMSC_CALL(get_endpoint, &peer_info->proc->super); + + if(!peer_info->smsc_ep) { + opal_output_verbose(MCA_BASE_VERBOSE_ERROR, + ompi_coll_base_framework.framework_output, + "coll:xhc: Error: Failed to initialize smsc endpoint"); + + return NULL; + } + } + + return peer_info->smsc_ep; +} + +// ------------------------------------------------ + +void *mca_coll_xhc_get_cico(xhc_peer_info_t *peer_info, int rank) { + if(OMPI_XHC_CICO_MAX == 0) { + return NULL; + } + + if(peer_info[rank].cico_buffer == NULL) { + peer_info[rank].cico_buffer = xhc_shmem_attach(&peer_info[rank].cico_ds); + } + + return peer_info[rank].cico_buffer; +} + +int mca_coll_xhc_copy_expose_region(void *base, size_t len, xhc_copy_data_t **region_data) { + if(mca_smsc_base_has_feature(MCA_SMSC_FEATURE_REQUIRE_REGISTRATION)) { + void *data = MCA_SMSC_CALL(register_region, base, len); + + if(data == NULL) { + opal_output_verbose(MCA_BASE_VERBOSE_ERROR, + ompi_coll_base_framework.framework_output, + "coll:xhc: Error: Failed to register memory region with smsc"); + + return 
-1; + } + + *region_data = data; + } + + return 0; +} + +void mca_coll_xhc_copy_region_post(void *dst, xhc_copy_data_t *region_data) { + memcpy(dst, region_data, mca_smsc_base_registration_data_size()); +} + +int mca_coll_xhc_copy_from(xhc_peer_info_t *peer_info, + void *dst, void *src, size_t size, void *access_token) { + + mca_smsc_endpoint_t *smsc_ep = xhc_smsc_ep(peer_info); + + if(smsc_ep == NULL) { + return -1; + } + + int status = MCA_SMSC_CALL(copy_from, smsc_ep, + dst, src, size, access_token); + + return (status == OPAL_SUCCESS ? 0 : -1); +} + +void mca_coll_xhc_copy_close_region(xhc_copy_data_t *region_data) { + if(mca_smsc_base_has_feature(MCA_SMSC_FEATURE_REQUIRE_REGISTRATION)) + MCA_SMSC_CALL(deregister_region, region_data); +} + +void *mca_coll_xhc_get_registration(xhc_peer_info_t *peer_info, + void *peer_vaddr, size_t size, xhc_reg_t **reg) { + + mca_smsc_endpoint_t *smsc_ep = xhc_smsc_ep(peer_info); + + if(smsc_ep == NULL) { + return NULL; + } + + /* MCA_RCACHE_FLAGS_PERSIST will cause the registration to stick around. + * Though actually, because smsc/xpmem initializes the ref count to 2, + * as a means of keeping the registration around (instead of using the + * flag), our flag here doesn't have much effect. If at some point we + * would wish to actually detach memory in some or all cases, we should + * either call the unmap method twice, or reach out to Open MPI devs and + * inquire about the ref count. */ + + void *local_ptr; + + *reg = MCA_SMSC_CALL(map_peer_region, smsc_ep, + MCA_RCACHE_FLAGS_PERSIST, peer_vaddr, size, &local_ptr); + + if(*reg == NULL) { + return NULL; + } + + return local_ptr; +} + +/* Won't actually unmap/detach, since we've set + * the "persist" flag while creating the mapping */ +void mca_coll_xhc_return_registration(xhc_reg_t *reg) { + MCA_SMSC_CALL(unmap_peer_region, reg); +} diff --git a/ompi/mca/coll/xhc/coll_xhc.h b/ompi/mca/coll/xhc/coll_xhc.h new file mode 100644 index 00000000000..0de32f03b46 --- /dev/null +++ b/ompi/mca/coll/xhc/coll_xhc.h @@ -0,0 +1,514 @@ +/* + * Copyright (c) 2021-2023 Computer Architecture and VLSI Systems (CARV) + * Laboratory, ICS Forth. All rights reserved. + * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +#ifndef MCA_COLL_XHC_EXPORT_H +#define MCA_COLL_XHC_EXPORT_H + +#include "ompi_config.h" + +#include +#include + +#include "mpi.h" + +#include "ompi/mca/mca.h" +#include "ompi/mca/coll/coll.h" +#include "ompi/mca/coll/base/base.h" +#include "ompi/communicator/communicator.h" +#include "ompi/datatype/ompi_datatype.h" +#include "ompi/op/op.h" + +#include "opal/mca/shmem/shmem.h" +#include "opal/mca/smsc/smsc.h" + +#include "coll_xhc_atomic.h" + +#define RETURN_WITH_ERROR(var, err, label) do {(var) = (err); goto label;} \ + while(0) + +#define OBJ_RELEASE_IF_NOT_NULL(obj) do {if((obj) != NULL) OBJ_RELEASE(obj);} while(0) + +#define REALLOC(p, s, t) do {void *_tmp = realloc(p, (s)*sizeof(t)); \ + if(_tmp) (p) = _tmp;} while(0) + +#define PEER_IS_LOCAL(peer_info, rank, loc) \ + (((peer_info)[(rank)].locality & (loc)) == (loc)) + +#define OMPI_XHC_LOC_EXT_BITS (8*(sizeof(xhc_loc_t) - sizeof(opal_hwloc_locality_t))) +#define OMPI_XHC_LOC_EXT_START (8*sizeof(opal_hwloc_locality_t)) + +// --- + +#define OMPI_XHC_ACK_WIN 0 + +// Align to CPU cache line (portable way to obtain it?) +#define OMPI_XHC_ALIGN 64 + +// Call opal_progress every this many ticks when busy-waiting +#define OMPI_XHC_OPAL_PROGRESS_CYCLE 10000 + +/* Reduction leader-member load balancing, AKA should leaders reduce data? 
+ * Normally, non-leaders reduce and leaders propagate. But there are instances + * where leaders can/should also help with the group's reduction load. + * + * OMPI_XHC_LB_RLA_TOP_LEVEL: The top level's leader performs reductions + * on the top level as if a common member + * + * OMPI_XHC_LB_RLA_FIRST_CHUNK: Leaders reduce only a single chunk, on + * each level, at the beginning of the operation + * + * (OMPI_XHC_LB_RLA_TOP_LEVEL and OMPI_XHC_LB_RLA_FIRST_CHUNK are combinable) + * + * OMPI_XHC_LB_RLM_ALL: All leaders performs reductions exactly as if + * common members + * + * Generally, we might not want leaders reducing, as that may lead to load + * imbalance, since they will also have to reduce the comm's result(s) + * on upper levels. Unless a leader is also one on all levels! (e.g. the + * top-level leader). This leader should probably be assisting in the + * reduction; otherwise, the only thing he will be doing is checking + * and updating synchronization flags. + * + * Regarding the load balancing problem, the leaders will actually not have + * anything to do until the first chunk is reduced, so they might as well be + * made to help the other members with this first chunk. Keep in mind though, + * this might increase the memory load, and cause this first chunk to take + * slightly more time to be produced. */ +#define OMPI_XHC_LB_RLA_TOP_LEVEL 0x01 +#define OMPI_XHC_LB_RLA_FIRST_CHUNK 0x02 +#define OMPI_XHC_LB_RLA_ALL 0x80 + +enum { + OMPI_XHC_DYNAMIC_REDUCE_DISABLED, + OMPI_XHC_DYNAMIC_REDUCE_NON_FLOAT, + OMPI_XHC_DYNAMIC_REDUCE_ALL +}; + +#define OMPI_XHC_CICO_MAX (mca_coll_xhc_component.cico_max) + +/* For other configuration options and default + * values check coll_xhc_component.c */ + +// --- + +BEGIN_C_DECLS + +// ---------------------------------------- + +typedef uint32_t xhc_loc_t; +typedef void xhc_reg_t; +typedef void xhc_copy_data_t; + +typedef struct mca_coll_xhc_component_t mca_coll_xhc_component_t; +typedef struct mca_coll_xhc_module_t mca_coll_xhc_module_t; +typedef struct mca_coll_xhc_module_t xhc_module_t; + +typedef struct xhc_coll_fns_t xhc_coll_fns_t; +typedef struct xhc_peer_info_t xhc_peer_info_t; + +typedef struct xhc_data_t xhc_data_t; +typedef struct xhc_comm_t xhc_comm_t; + +typedef struct xhc_comm_ctrl_t xhc_comm_ctrl_t; +typedef struct xhc_member_ctrl_t xhc_member_ctrl_t; +typedef struct xhc_member_info_t xhc_member_info_t; + +typedef struct xhc_reduce_area_t xhc_reduce_area_t; +typedef struct xhc_reduce_queue_item_t xhc_rq_item_t; + +typedef struct xhc_rank_range_t xhc_rank_range_t; +typedef struct xhc_loc_def_t xhc_loc_def_t; + +OMPI_DECLSPEC extern mca_coll_xhc_component_t mca_coll_xhc_component; +OMPI_DECLSPEC OBJ_CLASS_DECLARATION(mca_coll_xhc_module_t); +OMPI_DECLSPEC OBJ_CLASS_DECLARATION(xhc_rq_item_t); +OMPI_DECLSPEC OBJ_CLASS_DECLARATION(xhc_loc_def_item_t); + +// ---------------------------------------- + +struct xhc_coll_fns_t { + mca_coll_base_module_allreduce_fn_t coll_allreduce; + mca_coll_base_module_t *coll_allreduce_module; + + mca_coll_base_module_barrier_fn_t coll_barrier; + mca_coll_base_module_t *coll_barrier_module; + + mca_coll_base_module_bcast_fn_t coll_bcast; + mca_coll_base_module_t *coll_bcast_module; + + mca_coll_base_module_reduce_fn_t coll_reduce; + mca_coll_base_module_t *coll_reduce_module; +}; + +struct mca_coll_xhc_component_t { + mca_coll_base_component_t super; + + int priority; + bool print_info; + + char *shmem_backing; + + bool dynamic_leader; + + int barrier_root; + + int dynamic_reduce; + int 
lb_reduce_leader_assist; + + bool force_reduce; + + bool uniform_chunks; + size_t uniform_chunks_min; + + size_t cico_max; + + char *hierarchy_mca; + char *chunk_size_mca; +}; + +struct mca_coll_xhc_module_t { + mca_coll_base_module_t super; + + /* pointers to functions/modules of + * previous coll components for fallback */ + xhc_coll_fns_t prev_colls; + + // copied from comm + int comm_size; + int rank; + + // list of localities to consider during grouping + char *hierarchy_string; + xhc_loc_t *hierarchy; + int hierarchy_len; + + // list of requested chunk sizes, to be applied to comms + size_t *chunks; + int chunks_len; + + // temporary (private) internal buffer, for methods like Reduce + void *rbuf; + size_t rbuf_size; + + // xhc-specific info for every other rank in the comm + xhc_peer_info_t *peer_info; + + xhc_data_t *data; + + bool init; +}; + +struct xhc_peer_info_t { + xhc_loc_t locality; + + ompi_proc_t *proc; + mca_smsc_endpoint_t *smsc_ep; + + opal_shmem_ds_t cico_ds; + void *cico_buffer; +}; + +struct xhc_data_t { + xhc_comm_t *comms; + int comm_count; + + xf_sig_t pvt_coll_seq; +}; + +struct xhc_comm_t { + xhc_loc_t locality; + size_t chunk_size; + + int size; + int manager_rank; + int member_id; + + // --- + + // Am I a leader in the current collective? + bool is_coll_leader; + + // Have handshaked with all members in the current op? (useful to leader) + bool all_joined; + + /* A reduce set defines a range/area of data to be reduced, and its + * settings. We require multiple areas, because there might be different + * circumstances: + * + * 1. Under certain load balancing policies, leaders perform reductions + * for the just one chunk, and then they don't. Thus, the worker count + * changes, and the settings have to recomputed for the next areas. + * + * 2. During the "middle" of the operation, all members continuously + * reduce data in maximum-sized pieces (according to the configured + * chunk size). But, towards the end of the operation, the remaining + * elements are less than ((workers * elem_chunk)), we have to + * recalculate `elem_chunk`, so that all workers will perform + * equal work. 
*/ + struct xhc_reduce_area_t { + int start; // where the area begins + int len; // the size of the area + int workers; // how many processes perform reductions in the area + int stride; /* how much to advance inside the area after + * each reduction, unused for non-combo areas */ + + // local process settings + int work_begin; // where to begin the first reduction from + int work_end; // up to where to reduce + int work_chunk; // how much to reduce each time + int work_leftover; /* assigned leftover elements to include as + * part of the last reduction in the area */ + } reduce_area[3]; + int n_reduce_areas; + + struct xhc_member_info_t { + xhc_reg_t *sbuf_reg, *rbuf_reg; + void *sbuf, *rbuf; + bool init; + } *member_info; + + // Queue to keep track of individual reduction progress for different peers + opal_list_t *reduce_queue; + + // --- + + xhc_comm_ctrl_t *comm_ctrl; + xhc_member_ctrl_t *member_ctrl; + + opal_shmem_ds_t ctrl_ds; + + // --- + + xhc_member_ctrl_t *my_member_ctrl; // = &member_ctrl[member_id] + xhc_member_info_t *my_member_info; // = &member_info[member_id] +}; + +struct xhc_comm_ctrl_t { + // We want leader_seq, coll_ack, coll_seq to all lie in their own cache lines + + volatile xf_sig_t leader_seq; + + volatile xf_sig_t coll_ack __attribute__((aligned(OMPI_XHC_ALIGN))); + + volatile xf_sig_t coll_seq __attribute__((aligned(OMPI_XHC_ALIGN))); + + /* - Reason *NOT* to keep below fields in the same cache line as coll_seq: + * + * While members busy-wait on leader's coll_seq, initializing the rest of + * the fields will trigger cache-coherency-related "invalidate" and then + * "read miss" messages, for each store. + * + * - Reason to *DO* keep below fields in the same cache line as coll_seq: + * + * Members load from coll_seq, and implicitly fetch the entire cache + * line, which also contains the values of the other fields, that will + * also need to be loaded soon. + * + * (not 100% sure of my description here) + * + * Bcast seemed to perform better with the second option, so I went with + * that one. The best option might also be influenced by the ranks' order + * of entering in the operation. + */ + + // "Guarded" by members' coll_seq + volatile int leader_id; + volatile int leader_rank; + volatile int cico_id; + + void* volatile data_vaddr; + volatile xf_size_t bytes_ready; + + char access_token[]; +} __attribute__((aligned(OMPI_XHC_ALIGN))); + +struct xhc_member_ctrl_t { + volatile xf_sig_t member_ack; // written by member + + // written by member, at beginning of operation + volatile xf_sig_t member_seq __attribute__((aligned(OMPI_XHC_ALIGN))); + volatile int rank; + + void* volatile sbuf_vaddr; + void* volatile rbuf_vaddr; + volatile int cico_id; + + // reduction progress counters, written by member + volatile xf_int_t reduce_ready; + volatile xf_int_t reduce_done; +} __attribute__((aligned(OMPI_XHC_ALIGN))); + +struct xhc_reduce_queue_item_t { + opal_list_item_t super; + int member; // ID of member + int count; // current reduction progress for member + int area_id; // current reduce area +}; + +// ---------------------------------------- + +struct xhc_rank_range_t { + int start_rank, end_rank; +}; + +struct xhc_loc_def_t { + opal_list_item_t super; + + opal_hwloc_locality_t named_loc; + + xhc_rank_range_t *rank_list; + int rank_list_len; + + int split; + int max_ranks; + + bool repeat; +}; + +// ---------------------------------------- + +// coll_xhc_component.c +// -------------------- + +#define xhc_component_parse_hierarchy(...) 
mca_coll_xhc_component_parse_hierarchy(__VA_ARGS__) +#define xhc_component_parse_chunk_sizes(...) mca_coll_xhc_component_parse_chunk_sizes(__VA_ARGS__) + +int mca_coll_xhc_component_init_query(bool enable_progress_threads, + bool enable_mpi_threads); + +int mca_coll_xhc_component_parse_hierarchy(const char *val_str, + opal_list_t **level_defs_dst, int *nlevel_defs_dst); +int mca_coll_xhc_component_parse_chunk_sizes(const char *val_str, + size_t **vals_dst, int *len_dst); + +// coll_xhc_module.c +// ----------------- + +#define xhc_module_install_fns(...) mca_coll_xhc_module_install_fns(__VA_ARGS__) +#define xhc_module_install_fallback_fns(...) mca_coll_xhc_module_install_fallback_fns(__VA_ARGS__) + +#define xhc_module_prepare_hierarchy(...) mca_coll_xhc_module_prepare_hierarchy(__VA_ARGS__) + +mca_coll_base_module_t *mca_coll_xhc_module_comm_query( + ompi_communicator_t *comm, int *priority); + +int mca_coll_xhc_module_enable(mca_coll_base_module_t *module, + ompi_communicator_t *comm); +int mca_coll_xhc_module_disable(mca_coll_base_module_t *module, + ompi_communicator_t *comm); + +void mca_coll_xhc_module_install_fallback_fns(xhc_module_t *module, + ompi_communicator_t *comm, xhc_coll_fns_t *prev_fns_dst); +void mca_coll_xhc_module_install_fns(xhc_module_t *module, + ompi_communicator_t *comm, xhc_coll_fns_t fns); + +int mca_coll_xhc_module_prepare_hierarchy(mca_coll_xhc_module_t *module, + ompi_communicator_t *comm); + +// coll_xhc.c +// ---------- + +#define xhc_lazy_init(...) mca_coll_xhc_lazy_init(__VA_ARGS__) +#define xhc_fini(...) mca_coll_xhc_fini(__VA_ARGS__) + +#define xhc_get_cico(...) mca_coll_xhc_get_cico(__VA_ARGS__) + +#define xhc_copy_expose_region(...) mca_coll_xhc_copy_expose_region(__VA_ARGS__) +#define xhc_copy_region_post(...) mca_coll_xhc_copy_region_post(__VA_ARGS__) +#define xhc_copy_from(...) mca_coll_xhc_copy_from(__VA_ARGS__) +#define xhc_copy_close_region(...) mca_coll_xhc_copy_close_region(__VA_ARGS__) + +#define xhc_get_registration(...) mca_coll_xhc_get_registration(__VA_ARGS__) +#define xhc_return_registration(...) 
mca_coll_xhc_return_registration(__VA_ARGS__) + +int mca_coll_xhc_lazy_init(mca_coll_xhc_module_t *module, ompi_communicator_t *comm); +void mca_coll_xhc_fini(mca_coll_xhc_module_t *module); + +void *mca_coll_xhc_get_cico(xhc_peer_info_t *peer_info, int rank); + +int mca_coll_xhc_copy_expose_region(void *base, size_t len, xhc_copy_data_t **region_data); +void mca_coll_xhc_copy_region_post(void *dst, xhc_copy_data_t *region_data); +int mca_coll_xhc_copy_from(xhc_peer_info_t *peer_info, void *dst, + void *src, size_t size, void *access_token); +void mca_coll_xhc_copy_close_region(xhc_copy_data_t *region_data); + +void *mca_coll_xhc_get_registration(xhc_peer_info_t *peer_info, + void *peer_vaddr, size_t size, xhc_reg_t **reg); +void mca_coll_xhc_return_registration(xhc_reg_t *reg); + +// Primitives (respective file) +// ---------------------------- + +int mca_coll_xhc_bcast(void *buf, int count, ompi_datatype_t *datatype, + int root, ompi_communicator_t *comm, mca_coll_base_module_t *module); + +int mca_coll_xhc_barrier(ompi_communicator_t *ompi_comm, + mca_coll_base_module_t *module); + +int mca_coll_xhc_reduce(const void *sbuf, void *rbuf, + int count, ompi_datatype_t *datatype, ompi_op_t *op, int root, + ompi_communicator_t *comm, mca_coll_base_module_t *module); + +int mca_coll_xhc_allreduce(const void *sbuf, void *rbuf, + int count, ompi_datatype_t *datatype, ompi_op_t *op, + ompi_communicator_t *comm, mca_coll_base_module_t *module); + +// Miscellaneous +// ------------- + +#define xhc_allreduce_internal(...) mca_coll_xhc_allreduce_internal(__VA_ARGS__) + +int mca_coll_xhc_allreduce_internal(const void *sbuf, void *rbuf, int count, + ompi_datatype_t *datatype, ompi_op_t *op, ompi_communicator_t *ompi_comm, + mca_coll_base_module_t *module, bool require_bcast); + +// ---------------------------------------- + +// Rollover-safe check that flag has reached/exceeded thresh, with max deviation +static inline bool CHECK_FLAG(volatile xf_sig_t *flag, + xf_sig_t thresh, xf_sig_t win) { + + // This is okay because xf_sig_t is unsigned. Take care. + // The cast's necessity is dependent on the size of xf_sig_t + return ((xf_sig_t) (*flag - thresh) <= win); +} + +static inline void WAIT_FLAG(volatile xf_sig_t *flag, + xf_sig_t thresh, xf_sig_t win) { + bool ready = false; + + do { + for(int i = 0; i < OMPI_XHC_OPAL_PROGRESS_CYCLE; i++) { + if(CHECK_FLAG(flag, thresh, win)) { + ready = true; + break; + } + + /* xf_sig_t f = *flag; + if(CHECK_FLAG(&f, thresh, win)) { + ready = true; + break; + } else if(CHECK_FLAG(&f, thresh, 1000)) + printf("Debug: Flag check with window %d failed, " + "but succeeded with window 1000. flag = %d, " + "thresh = %d\n", win, f, thresh); */ + } + + if(!ready) { + opal_progress(); + } + } while(!ready); +} + +// ---------------------------------------- + +END_C_DECLS + +#endif diff --git a/ompi/mca/coll/xhc/coll_xhc_allreduce.c b/ompi/mca/coll/xhc/coll_xhc_allreduce.c new file mode 100644 index 00000000000..d45065b9dc0 --- /dev/null +++ b/ompi/mca/coll/xhc/coll_xhc_allreduce.c @@ -0,0 +1,1121 @@ +/* + * Copyright (c) 2021-2023 Computer Architecture and VLSI Systems (CARV) + * Laboratory, ICS Forth. All rights reserved. 
+ * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +#include "ompi_config.h" +#include "mpi.h" + +#include "ompi/constants.h" +#include "ompi/datatype/ompi_datatype.h" +#include "ompi/communicator/communicator.h" +#include "ompi/op/op.h" + +#include "opal/mca/rcache/base/base.h" +#include "opal/util/show_help.h" +#include "opal/util/minmax.h" + +#include "coll_xhc.h" + +#define MAX_REDUCE_AREAS(comm) \ + ((int)(sizeof((comm)->reduce_area)/sizeof((comm)->reduce_area[0]))) + +OBJ_CLASS_INSTANCE(xhc_rq_item_t, opal_list_item_t, NULL, NULL); + +// ----------------------------- + +/* For the reduction areas, see comments in xhc_reduce_area_t's definition. + * For the leader reduction assistance policies see the flag definitions. */ +static void init_reduce_areas(xhc_comm_t *comms, + int comm_count, int allreduce_count, size_t dtype_size) { + + bool uniform_chunks = mca_coll_xhc_component.uniform_chunks; + int lb_rla = mca_coll_xhc_component.lb_reduce_leader_assist; + + for(int i = 0; i < comm_count; i++) { + xhc_comm_t *xc = &comms[i]; + + int avail_workers[MAX_REDUCE_AREAS(xc)]; + + for(int area_id = 0; area_id < MAX_REDUCE_AREAS(xc); area_id++) { + int workers = xc->size - 1; + + if(lb_rla & OMPI_XHC_LB_RLA_TOP_LEVEL) { + if(i == comm_count - 1 && workers < xc->size) + workers++; + } + + if(lb_rla & OMPI_XHC_LB_RLA_FIRST_CHUNK) { + if(area_id == 0 && workers < xc->size) + workers++; + } + + if(lb_rla & OMPI_XHC_LB_RLA_ALL) { + workers = xc->size; + } + + avail_workers[area_id] = workers; + } + + // Min/max work that a worker may perform (one step) + int min_elems = mca_coll_xhc_component.uniform_chunks_min / dtype_size; + int max_elems = xc->chunk_size / dtype_size; + + int area_id = 0, el_idx = 0; + + while(area_id < MAX_REDUCE_AREAS(xc) && el_idx < allreduce_count) { + xhc_reduce_area_t *area = &xc->reduce_area[area_id]; + + *area = (xhc_reduce_area_t) {0}; + + int remaining = allreduce_count - el_idx; + int workers = avail_workers[area_id]; + + int elems_per_member; + int repeat = 0; + + int area_elems = opal_min(max_elems * workers, remaining); + + /* We should consider the future size of the next area. If it's + * too small in relation to the minimum chunk (min_elems), some + * workers of the next area won't perform work, leading to load + * imbalance. In this case, we elect to either shrink the current + * area so that we will be able to better balance the load in the + * next one, or if the elements that remain for the next area are + * especially few, we make this area absorb the next one. + * Specifically, we absorb it if the increase of each worker's + * load is no more than 10% of the maximum load set. */ + if(uniform_chunks && area_id < MAX_REDUCE_AREAS(xc) - 1) { + int next_workers = avail_workers[area_id+1]; + int next_remaining = allreduce_count - (el_idx + area_elems); + + if(next_remaining < next_workers * min_elems) { + if(next_remaining/workers <= max_elems/10) { + area_elems += next_remaining; + } else { + int ideal_donate = next_workers * min_elems - next_remaining; + + /* Don't donate so much elements that this area + * won't cover its own min reduction chunk size */ + int max_donate = area_elems - workers * min_elems; + max_donate = (max_donate > 0 ? max_donate : 0); + + area_elems -= opal_min(ideal_donate, max_donate); + } + } + } + + if(uniform_chunks) { + /* The elements might not be enough for every worker to do + * work. 
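+ * For instance (hypothetical numbers, purely for illustration): with 1000
+ * elements left in the area, min_elems = 256, and 8 candidate workers,
+ * only 1000/256 = 3 workers can each receive at least min_elems, and
+ * elems_per_member then becomes 1000/3 = 333.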
We calculate how many workers we need so that no + * one of them does less than min_elems work, and use the + * result to calculate the final elements per member. */ + workers = opal_min(area_elems/min_elems, workers); + workers = opal_max(workers, 1); + + elems_per_member = area_elems / workers; + } else { + elems_per_member = max_elems; + workers = area_elems/max_elems; + } + + // If this is the middle area, try to maximize its size + if(area_id == 1 && workers > 0) { + int set = workers * elems_per_member; + repeat = (int)((remaining-area_elems)/set); + area_elems += repeat * set; + } + + area->start = el_idx; + area->len = area_elems; + area->workers = workers; + area->stride = workers * elems_per_member; + + /* My ID, assuming that if some member is not reducing, it is + * the one with ID=0, because currently only member 0 becomes + * the leader, and the leader is the only one that might not + * be reducing. */ + int worker_id = xc->member_id - (xc->size - avail_workers[area_id]); + + area->work_begin = el_idx + worker_id * elems_per_member; + area->work_chunk = (worker_id >= 0 && worker_id < workers ? + elems_per_member : 0); + + area->work_leftover = 0; + + int leftover_elems = (workers > 0 ? + (area_elems % (workers * elems_per_member)) : area_elems); + if(leftover_elems) { + if(worker_id == (uniform_chunks ? workers - 1 : workers)) { + area->work_leftover = leftover_elems; + } + } + + area->work_end = area->work_begin + (repeat * area->stride) + + area->work_chunk + area->work_leftover; + + el_idx += area_elems; + area_id++; + } + + assert(el_idx == allreduce_count); + + xc->n_reduce_areas = area_id; + + // Erase zero-work areas + while(xc->n_reduce_areas > 0 + && xc->reduce_area[xc->n_reduce_areas - 1].work_chunk == 0 + && xc->reduce_area[xc->n_reduce_areas - 1].work_leftover == 0) { + xc->n_reduce_areas--; + } + + /* If not a leader on this comm, nothing + * to do on next ones whatsoever */ + if(!xc->is_coll_leader) { + break; + } + } +} + +static void xhc_allreduce_init_local(xhc_comm_t *comms, int comm_count, + int allreduce_count, size_t dtype_size, xf_sig_t seq) { + + for(int i = 0; i < comm_count; i++) { + xhc_comm_t *xc = &comms[i]; + + xc->is_coll_leader = false; + + for(int m = 0; m < xc->size; m++) { + xc->member_info[m] = (xhc_member_info_t) {0}; + } + + xc->all_joined = false; + } + + for(int i = 0; i < comm_count; i++) { + xhc_comm_t *xc = &comms[i]; + + /* The manager is the leader. Even in the dynamic reduce case, + * there (currently) shouldn't be any real benefit from the + * leader being dynamic in allreduce. */ + if(xc->member_id != 0) { + break; + } + + xc->comm_ctrl->leader_seq = seq; + xc->is_coll_leader = true; + } + + init_reduce_areas(comms, comm_count, allreduce_count, dtype_size); + + for(int i = 0; i < comm_count; i++) { + xhc_comm_t *xc = &comms[i]; + + int initial_count = (xc->n_reduce_areas > 0 ? 
+ xc->reduce_area[0].work_begin : allreduce_count); + + int m = 0; + OPAL_LIST_FOREACH_DECL(item, xc->reduce_queue, xhc_rq_item_t) { + if(m == xc->member_id) { + m++; + } + + *item = (xhc_rq_item_t) {.super = item->super, .member = m++, + .count = initial_count, .area_id = 0}; + } + + if(!xc->is_coll_leader) { + break; + } + } +} + +static void xhc_allreduce_init_comm(xhc_comm_t *comms, int comm_count, + void *rbuf, bool do_cico, int ompi_rank, xf_sig_t seq) { + + for(int i = 0; i < comm_count; i++) { + xhc_comm_t *xc = &comms[i]; + + if(!xc->is_coll_leader) { + break; + } + + WAIT_FLAG(&xc->comm_ctrl->coll_ack, seq - 1, 0); + + /* Because there is a control dependency with the load + * from coll_ack above and the code below, and because + * it is a load-store one (not load-load), I declare + * that a read-memory-barrier is not required here. */ + + xc->comm_ctrl->leader_id = xc->member_id; + xc->comm_ctrl->leader_rank = ompi_rank; + xc->comm_ctrl->data_vaddr = (!do_cico ? rbuf : NULL); + xc->comm_ctrl->bytes_ready = 0; + + xhc_atomic_wmb(); + + xc->comm_ctrl->coll_seq = seq; + } +} + +static void xhc_allreduce_init_member(xhc_comm_t *comms, int comm_count, + xhc_peer_info_t *peer_info, void *sbuf, void *rbuf, int allreduce_count, + bool do_cico, int ompi_rank, xf_sig_t seq) { + + for(int i = 0; i < comm_count; i++) { + xhc_comm_t *xc = &comms[i]; + + /* Essentially the value of reduce area-0's + * work_begin, as set in init_local() */ + int rq_first_count = ((xhc_rq_item_t *) + opal_list_get_first(xc->reduce_queue))->count; + + /* Make sure that the previous owner of my member ctrl (tip: can + * occur with dynamic leadership (or non-zero root!?), when it is + * implemented ^^) is not still using it. Also not that this + * previous owner will set member_ack only after the comm's coll_ack + * is set, so it also guarantees that no other member in the comm is + * accessing the member's flags from a previous collective. */ + WAIT_FLAG(&xc->my_member_ctrl->member_ack, seq - 1, 0); + + xc->my_member_ctrl->reduce_done = rq_first_count; + xc->my_member_ctrl->reduce_ready = (i == 0 && !do_cico ? allreduce_count : 0); + + xc->my_member_ctrl->rank = ompi_rank; + + if(!do_cico) { + xc->my_member_ctrl->sbuf_vaddr = (i == 0 ? sbuf : rbuf); + xc->my_member_ctrl->rbuf_vaddr = (xc->is_coll_leader ? rbuf : NULL); + + xc->my_member_ctrl->cico_id = -1; + + xc->my_member_info->sbuf = (i == 0 ? sbuf : rbuf); + xc->my_member_info->rbuf = rbuf; + } else { + xc->my_member_ctrl->sbuf_vaddr = NULL; + xc->my_member_ctrl->rbuf_vaddr = NULL; + + int cico_id = (i == 0 ? 
ompi_rank : comms[i-1].manager_rank); + xc->my_member_ctrl->cico_id = cico_id; + + xc->my_member_info->sbuf = xhc_get_cico(peer_info, cico_id); + xc->my_member_info->rbuf = xhc_get_cico(peer_info, ompi_rank); + } + + xhc_atomic_wmb(); + xc->my_member_ctrl->member_seq = seq; + + if(!xc->is_coll_leader) { + break; + } + } +} + +// ----------------------------- + +static int xhc_allreduce_attach_member(xhc_comm_t *xc, int member, + xhc_peer_info_t *peer_info, size_t bytes, bool do_cico, xf_sig_t seq) { + + if(xc->member_info[member].init) { + return 0; + } + + if(!do_cico) { + int member_rank = xc->member_ctrl[member].rank; + + void *sbuf_vaddr = xc->member_ctrl[member].sbuf_vaddr; + void *rbuf_vaddr = xc->member_ctrl[member].rbuf_vaddr; + + xc->member_info[member].sbuf = xhc_get_registration( + &peer_info[member_rank], sbuf_vaddr, bytes, + &xc->member_info[member].sbuf_reg); + + if(xc->member_info[member].sbuf == NULL) { + return -1; + } + + // Leaders will also share their rbuf + if(rbuf_vaddr) { + if(rbuf_vaddr != sbuf_vaddr) { + xc->member_info[member].rbuf = xhc_get_registration( + &peer_info[member_rank], rbuf_vaddr, bytes, + &xc->member_info[member].rbuf_reg); + + if(xc->member_info[member].rbuf == NULL) { + return -1; + } + } else + xc->member_info[member].rbuf = xc->member_info[member].sbuf; + } + } else { + /* Here's the deal with CICO buffers and the comm's manager: In order + * to avoid excessive amounts of attachments, ranks that are + * foreign to a comm only attach to the comm's manager's CICO buffer, + * instead of to every member's. Therefore, members will place their + * final data in the manager's CICO buffer, instead of the leader's + * (even though the leader and the manager actually very often are one + * and the same..). */ + + xc->member_info[member].sbuf = xhc_get_cico(peer_info, + xc->member_ctrl[member].cico_id); + + if(CHECK_FLAG(&xc->comm_ctrl->coll_seq, seq, 0) + && member == xc->comm_ctrl->leader_id) { + xc->member_info[member].rbuf = xhc_get_cico(peer_info, xc->manager_rank); + } + } + + xc->member_info[member].init = true; + + return 0; +} + +static void xhc_allreduce_leader_check_all_joined(xhc_comm_t *xc, xf_sig_t seq) { + for(int m = 0; m < xc->size; m++) { + if(m == xc->member_id) { + continue; + } + + if(!CHECK_FLAG(&xc->member_ctrl[m].member_seq, seq, 0)) { + return; + } + } + + xc->all_joined = true; +} + +static void xhc_allreduce_disconnect_peers(xhc_comm_t *comms, int comm_count) { + xhc_comm_t *xc = comms; + + while(xc && xc->is_coll_leader) { + xc = (xc != &comms[comm_count-1] ? 
xc + 1 : NULL); + } + + if(xc == NULL) { + return; + } + + xhc_reg_t *reg; + + for(int m = 0; m < xc->size; m++) { + if(m == xc->member_id) { + continue; + } + + if((reg = xc->member_info[m].sbuf_reg)) { + xhc_return_registration(reg); + } + + if((reg = xc->member_info[m].rbuf_reg)) { + xhc_return_registration(reg); + } + } +} + +// ----------------------------- + +static xhc_comm_t *xhc_allreduce_bcast_src_comm(xhc_comm_t *comms, int comm_count) { + xhc_comm_t *s = NULL; + + for(int i = 0; i < comm_count; i++) { + if(!comms[i].is_coll_leader) { + s = &comms[i]; + break; + } + } + + return s; +} + +static void xhc_allreduce_do_ack(xhc_comm_t *comms, int comm_count, xf_sig_t seq) { + for(int i = 0; i < comm_count; i++) { + xhc_comm_t *xc = &comms[i]; + + xc->my_member_ctrl->member_ack = seq; + + if(!xc->is_coll_leader) { + break; + } + + for(int m = 0; m < xc->size; m++) { + if(m == xc->member_id) { + continue; + } + + WAIT_FLAG(&xc->member_ctrl[m].member_ack, seq, OMPI_XHC_ACK_WIN); + } + + xc->comm_ctrl->coll_ack = seq; + } +} + +// ----------------------------- + +static void xhc_allreduce_cico_publish(xhc_comm_t *xc, void *data_src, + xhc_peer_info_t *peer_info, int ompi_rank, int allreduce_count, + size_t dtype_size) { + + int ready = xc->my_member_ctrl->reduce_ready; + + /* The chunk size here is just a means of pipelining the CICO + * publishing, for whichever case this might be necessary in. + * There isn't really any reason to consult reduce areas and + * their chunk sizes here.*/ + int elements = opal_min(xc->chunk_size/dtype_size, allreduce_count - ready); + + void *src = (char *) data_src + ready * dtype_size; + void *dst = (char *) xhc_get_cico(peer_info, ompi_rank) + ready * dtype_size; + + memcpy(dst, src, elements * dtype_size); + xhc_atomic_wmb(); + + volatile xf_int_t *rrp = &xc->my_member_ctrl->reduce_ready; + xhc_atomic_store_int(rrp, ready + elements); +} + +static int xhc_allreduce_reduce_get_next(xhc_comm_t *xc, + xhc_peer_info_t *peer_info, int allreduce_count, + size_t dtype_size, bool do_cico, bool out_of_order_reduce, + xf_sig_t seq, xhc_rq_item_t **item_dst) { + + xhc_rq_item_t *member_item = NULL; + int stalled_member = xc->size; + + /* Iterate the reduce queue, to determine which member's data to reduce, + * and from what index. The reduction queue aids in the implementation of + * the rationale that members that are not ready at some point should be + * temporarily skipped, to prevent stalling in the collective. Reasons + * that a member may not be "ready" are (1) it has not yet joined the + * collective, (2) the necessary data have not yet been produced (eg. + * because the member's children have not finished their reduction on the + * previous communicator) or have not been copied to the CICO buffer. + * However, when floating point data is concerned, skipping members and + * therefore doing certain reductions in non-deterministic order results + * to reproducibility problems. Hence the existence of the "dynamic reduce" + * switch; when enabled, members are skipped when not ready. When disabled, + * members are skipped, but only the data of members with a lower ID that + * the one that has stalled can be reduced (eg. member 2 has stalled, but + * reduction for future chunks of members 0 and 1 (only, not of member 3, + * even if it is ready) will begin instead of completely stalling). The + * reduction queue is sorted according to the reduction progress counter in + * each entry. 
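+ *
+ * Illustrative scenario (made-up numbers): if the queue entries for three
+ * peers currently hold progress counters {4, 8, 8}, the entry with count 4
+ * sits at the head of the queue and is examined first. The head's counter
+ * is also what gets published as this member's reduce_done watermark, as in
+ * this sketch, which mirrors xhc_allreduce_reduce_return_item() further down:
+ *
+ *   xhc_rq_item_t *head = (xhc_rq_item_t *) opal_list_get_first(xc->reduce_queue);
+ *   if(head->count > xc->my_member_ctrl->reduce_done) {
+ *       xhc_atomic_wmb();
+ *       xhc_atomic_store_int(&xc->my_member_ctrl->reduce_done, head->count);
+ *   }
+ *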
This helps ensure fully reduced chunks are generated as soon + * as possible, so that leaders can quickly propagate them upwards. */ + OPAL_LIST_FOREACH_DECL(item, xc->reduce_queue, xhc_rq_item_t) { + int member = item->member; + + if(!xc->member_info[member].init + && CHECK_FLAG(&xc->member_ctrl[member].member_seq, seq, 0)) { + + xhc_atomic_rmb(); + + int ret = xhc_allreduce_attach_member(xc, member, peer_info, + allreduce_count * dtype_size, do_cico, seq); + + if(ret != 0) { + return ret; + } + } + + if(xc->member_info[member].init && item->count < allreduce_count) { + xhc_reduce_area_t *area = &xc->reduce_area[item->area_id]; + int elements = area->work_chunk; + + if(item->count + elements + area->work_leftover == area->work_end) { + elements += area->work_leftover; + } + + int self_ready = xc->my_member_ctrl->reduce_ready; + + volatile xf_int_t *rrp = &xc->member_ctrl[member].reduce_ready; + int member_ready = xhc_atomic_load_int(rrp); + + if(self_ready >= item->count + elements + && member_ready >= item->count + elements + && member < stalled_member) { + + member_item = item; + break; + } + } + + if(!out_of_order_reduce) { + stalled_member = opal_min(stalled_member, member); + } + } + + if(member_item) { + opal_list_remove_item(xc->reduce_queue, (opal_list_item_t *) member_item); + } + + *item_dst = member_item; + + return 0; +} + +static void xhc_allreduce_rq_item_analyze(xhc_comm_t *xc, xhc_rq_item_t *item, + bool *first_reduction, bool *last_reduction) { + + *first_reduction = false; + *last_reduction = false; + + if(opal_list_get_size(xc->reduce_queue) == 0) { + *first_reduction = true; + *last_reduction = true; + } else { + xhc_rq_item_t *first_item = (xhc_rq_item_t *) + opal_list_get_first(xc->reduce_queue); + + xhc_rq_item_t *last_item = (xhc_rq_item_t *) + opal_list_get_last(xc->reduce_queue); + + /* If this count is equal or larger than the last one, it means that + * no other count in the queue is larger than it. Therefore, this is the + * first reduction taking place for the "member_item->count" chunk idx. */ + if(item->count >= last_item->count) { + *first_reduction = true; + } + + /* If this count is uniquely minimum in the queue, this is the + * last reduction taking place for this specific chunk index. */ + if(item->count < first_item->count) { + *last_reduction = true; + } + } +} + +static void xhc_allreduce_do_reduce(xhc_comm_t *xc, xhc_rq_item_t *member_item, + int allreduce_count, ompi_datatype_t *dtype, size_t dtype_size, + ompi_op_t *op) { + + xhc_reduce_area_t *area = &xc->reduce_area[member_item->area_id]; + int elements = area->work_chunk; + + if(member_item->count + elements + area->work_leftover == area->work_end) { + elements += area->work_leftover; + } + + size_t offset = member_item->count * dtype_size; + + char *src = (char *) xc->member_info[member_item->member].sbuf + offset; + + char *dst; + char *src2 = NULL; + + bool first_reduction, last_reduction; + + xhc_allreduce_rq_item_analyze(xc, member_item, + &first_reduction, &last_reduction); + + /* Only access comm_ctrl when it's the last reduction. 
Otherwise, + * it's not guaranteed that the leader will have initialized it yet.*/ + if(last_reduction) { + dst = (char *) xc->member_info[xc->comm_ctrl->leader_id].rbuf + offset; + } else { + dst = (char *) xc->my_member_info->rbuf + offset; + } + + if(first_reduction) { + src2 = (char *) xc->my_member_info->sbuf + offset; + } else if(last_reduction) { + src2 = (char *) xc->my_member_info->rbuf + offset; + } + + // Happens under certain circumstances with MPI_IN_PLACE or with CICO + if(src2 == dst) { + src2 = NULL; + } else if(src == dst) { + src = src2; + src2 = NULL; + } + + xhc_atomic_rmb(); + + if(src2) { + ompi_3buff_op_reduce(op, src2, src, dst, elements, dtype); + } else { + ompi_op_reduce(op, src, dst, elements, dtype); + } + + /* If we reached the end of the area after this reduction, switch + * to the next one, or mark completion if it was the last one. + * Otherwise, adjust the count according to the area's parameters. */ + if(member_item->count + elements == area->work_end) { + if(member_item->area_id < xc->n_reduce_areas - 1) { + member_item->area_id++; + member_item->count = xc->reduce_area[member_item->area_id].work_begin; + } else { + member_item->count = allreduce_count; + } + } else { + member_item->count += area->stride; + } +} + +static void xhc_allreduce_reduce_return_item(xhc_comm_t *xc, + xhc_rq_item_t *member_item) { + + bool placed = false; + + xhc_rq_item_t *item; + OPAL_LIST_FOREACH_REV(item, xc->reduce_queue, xhc_rq_item_t) { + if(member_item->count >= item->count) { + opal_list_insert_pos(xc->reduce_queue, + (opal_list_item_t *) item->super.opal_list_next, + (opal_list_item_t *) member_item); + + placed = true; + break; + } + } + + if(!placed) { + opal_list_prepend(xc->reduce_queue, (opal_list_item_t *) member_item); + } + + xhc_rq_item_t *first_item = (xhc_rq_item_t *) + opal_list_get_first(xc->reduce_queue); + + if(first_item->count > xc->my_member_ctrl->reduce_done) { + xhc_atomic_wmb(); + + volatile xf_int_t *rdp = &xc->my_member_ctrl->reduce_done; + xhc_atomic_store_int(rdp, first_item->count); + } +} + +static void xhc_allreduce_do_bcast(xhc_comm_t *comms, int comm_count, + xhc_comm_t *src_comm, size_t bytes_total, size_t *bcast_done, + const void *bcast_src, void *bcast_dst, void *bcast_cico) { + + size_t copy_size = opal_min(src_comm->chunk_size, bytes_total - *bcast_done); + + volatile xf_size_t *brp = &src_comm->comm_ctrl->bytes_ready; + + if(xhc_atomic_load_size_t(brp) - *bcast_done >= copy_size) { + void *src = (char *) bcast_src + *bcast_done; + void *dst = (char *) bcast_dst + *bcast_done; + void *cico_dst = (char *) bcast_cico + *bcast_done; + + xhc_atomic_rmb(); + + if(bcast_cico && comms[0].is_coll_leader) { + memcpy(cico_dst, src, copy_size); + } else { + memcpy(dst, src, copy_size); + } + + *bcast_done += copy_size; + + xhc_atomic_wmb(); + + for(int i = 0; i < comm_count; i++) { + if(!comms[i].is_coll_leader) { + break; + } + + volatile xf_size_t *brp_d = &comms[i].comm_ctrl->bytes_ready; + xhc_atomic_store_size_t(brp_d, *bcast_done); + } + + if(bcast_cico && comms[0].is_coll_leader) { + memcpy(dst, cico_dst, copy_size); + } + } +} + +// ----------------------------- + +int mca_coll_xhc_allreduce_internal(const void *sbuf, void *rbuf, int count, + ompi_datatype_t *datatype, ompi_op_t *op, ompi_communicator_t *ompi_comm, + mca_coll_base_module_t *ompi_module, bool require_bcast) { + + xhc_module_t *module = (xhc_module_t *) ompi_module; + + if(!module->init) { + int ret = xhc_lazy_init(module, ompi_comm); + if(ret != OMPI_SUCCESS) { + return 
ret; + } + } + + if(!ompi_datatype_is_predefined(datatype)) { + static bool warn_shown = false; + + if(!warn_shown) { + opal_output_verbose(MCA_BASE_VERBOSE_WARN, + ompi_coll_base_framework.framework_output, + "coll:xhc: Warning: XHC does not currently support " + "derived datatypes; utilizing fallback component"); + warn_shown = true; + } + + xhc_coll_fns_t fallback = module->prev_colls; + + if(require_bcast) { + return fallback.coll_allreduce(sbuf, rbuf, count, datatype, + op, ompi_comm, fallback.coll_allreduce_module); + } else { + return fallback.coll_reduce(sbuf, rbuf, count, datatype, + op, 0, ompi_comm, fallback.coll_reduce_module); + } + } + + if(!ompi_op_is_commute(op)) { + static bool warn_shown = false; + + if(!warn_shown) { + opal_output_verbose(MCA_BASE_VERBOSE_WARN, + ompi_coll_base_framework.framework_output, + "coll:xhc: Warning: (all)reduce does not support non-commutative " + "operators; utilizing fallback component"); + warn_shown = true; + } + + xhc_coll_fns_t fallback = module->prev_colls; + + if(require_bcast) { + return fallback.coll_allreduce(sbuf, rbuf, count, datatype, + op, ompi_comm, fallback.coll_allreduce_module); + } else { + return fallback.coll_reduce(sbuf, rbuf, count, datatype, + op, 0, ompi_comm, fallback.coll_reduce_module); + } + } + + // ---- + + xhc_peer_info_t *peer_info = module->peer_info; + xhc_data_t *data = module->data; + + xhc_comm_t *comms = data->comms; + int comm_count = data->comm_count; + + size_t dtype_size, bytes_total; + ompi_datatype_type_size(datatype, &dtype_size); + bytes_total = count * dtype_size; + + bool do_cico = (bytes_total <= OMPI_XHC_CICO_MAX); + bool out_of_order_reduce = false; + + int rank = ompi_comm_rank(ompi_comm); + + // ---- + + switch(mca_coll_xhc_component.dynamic_reduce) { + case OMPI_XHC_DYNAMIC_REDUCE_DISABLED: + out_of_order_reduce = false; + break; + + case OMPI_XHC_DYNAMIC_REDUCE_NON_FLOAT: + out_of_order_reduce = !(datatype->super.flags & OMPI_DATATYPE_FLAG_DATA_FLOAT); + break; + + case OMPI_XHC_DYNAMIC_REDUCE_ALL: + out_of_order_reduce = true; + break; + } + + // ---- + + // rbuf won't be present for non-root ranks in MPI_Reduce + if(rbuf == NULL && !do_cico) { + if(module->rbuf_size < bytes_total) { + void *tmp = realloc(module->rbuf, bytes_total); + + if(tmp != NULL) { + module->rbuf = tmp; + module->rbuf_size = bytes_total; + } else { + return OPAL_ERR_OUT_OF_RESOURCE; + } + } + + rbuf = module->rbuf; + } + + // ---- + + xf_sig_t pvt_seq = ++data->pvt_coll_seq; + + if(sbuf == MPI_IN_PLACE) { + sbuf = rbuf; + } + + xhc_allreduce_init_local(comms, comm_count, count, dtype_size, pvt_seq); + xhc_allreduce_init_comm(comms, comm_count, rbuf, do_cico, rank, pvt_seq); + xhc_allreduce_init_member(comms, comm_count, peer_info, + (void *) sbuf, rbuf, count, do_cico, rank, pvt_seq); + + void *local_cico = xhc_get_cico(peer_info, comms[0].manager_rank); + + // My conscience is clear! + if(require_bcast) { + goto _allreduce; + } else { + goto _reduce; + } + +// ============================================================================= + +_allreduce: { + + xhc_comm_t *bcast_comm = + xhc_allreduce_bcast_src_comm(comms, comm_count); + + bool bcast_leader_joined = false; + + for(size_t bytes_done = 0; bytes_done < bytes_total; ) { + for(int i = 0; i < comm_count; i++) { + xhc_comm_t *xc = &comms[i]; + xhc_comm_t *xnc = (i < comm_count - 1 ? 
&comms[i+1] : NULL); + + if(do_cico && i == 0 && xc->my_member_ctrl->reduce_ready < count) { + xhc_allreduce_cico_publish(xc, (void *) sbuf, + peer_info, rank, count, dtype_size); + } + + if(xc->is_coll_leader) { + int completed = 0; + + if(!xc->all_joined) { + xhc_allreduce_leader_check_all_joined(xc, pvt_seq); + } + + if(xc->all_joined) { + completed = count; + + for(int m = 0; m < xc->size; m++) { + volatile xf_int_t *rdp = &xc->member_ctrl[m].reduce_done; + int member_done = xhc_atomic_load_int(rdp); + + /* Watch out for double evaluation here, don't perform + * sensitive loads inside opal_min()'s parameter list. */ + completed = opal_min(completed, member_done); + } + } + + if(xnc && completed > xnc->my_member_ctrl->reduce_ready) { + volatile xf_int_t *rrp = &xnc->my_member_ctrl->reduce_ready; + xhc_atomic_store_int(rrp, completed); + } else if(!xnc) { + size_t bytes_fully_reduced = completed * dtype_size; + + // Broadcast fully reduced data + if(bytes_fully_reduced > bytes_done) { + for(int k = 0; k < comm_count; k++) { + volatile xf_size_t *brp = + &comms[k].comm_ctrl->bytes_ready; + xhc_atomic_store_size_t(brp, bytes_fully_reduced); + } + + if(do_cico) { + void *src = (char *) local_cico + bytes_done; + void *dst = (char *) rbuf + bytes_done; + memcpy(dst, src, bytes_fully_reduced - bytes_done); + } + + bytes_done = bytes_fully_reduced; + } + } + } + + // Is the reduction phase completed? + if(xc->my_member_ctrl->reduce_done < count) { + xhc_rq_item_t *member_item = NULL; + + int ret = xhc_allreduce_reduce_get_next(xc, + peer_info, count, dtype_size, do_cico, + out_of_order_reduce, pvt_seq, &member_item); + + if(ret != 0) { + return OMPI_ERROR; + } + + if(member_item) { + xhc_allreduce_do_reduce(xc, member_item, + count, datatype, dtype_size, op); + + xhc_allreduce_reduce_return_item(xc, member_item); + } + } + + /* If not a leader in this comm, not + * participating in higher-up ones. */ + if(!xc->is_coll_leader) { + break; + } + } + + if(bcast_comm && !bcast_leader_joined) { + if(CHECK_FLAG(&bcast_comm->comm_ctrl->coll_seq, pvt_seq, 0)) { + xhc_atomic_rmb(); + + int leader = bcast_comm->comm_ctrl->leader_id; + + if(!bcast_comm->member_info[leader].init) { + WAIT_FLAG(&bcast_comm->member_ctrl[leader].member_seq, + pvt_seq, 0); + + xhc_atomic_rmb(); + + xhc_allreduce_attach_member(bcast_comm, leader, + peer_info, bytes_total, do_cico, pvt_seq); + } + + bcast_leader_joined = true; + } + } + + if(bcast_comm && bcast_leader_joined) { + int leader = bcast_comm->comm_ctrl->leader_id; + + xhc_allreduce_do_bcast(comms, comm_count, + bcast_comm, bytes_total, &bytes_done, + bcast_comm->member_info[leader].rbuf, + rbuf, (do_cico ? local_cico : NULL)); + } + } + + xhc_allreduce_do_ack(comms, comm_count, pvt_seq); + + goto _finish; +} + +// ============================================================================= + +_reduce: { + + size_t cico_copied = 0; + int completed_comms = 0; + + while(completed_comms < comm_count) { + for(int i = completed_comms; i < comm_count; i++) { + xhc_comm_t *xc = &comms[i]; + xhc_comm_t *xnc = (i < comm_count - 1 ? 
&comms[i+1] : NULL); + + if(do_cico && i == 0 && xc->my_member_ctrl->reduce_ready < count) { + xhc_allreduce_cico_publish(xc, (void *) sbuf, + peer_info, rank, count, dtype_size); + } + + if(xc->is_coll_leader) { + int completed = 0; + + if(!xc->all_joined) { + xhc_allreduce_leader_check_all_joined(xc, pvt_seq); + } + + if(xc->all_joined) { + completed = count; + + for(int m = 0; m < xc->size; m++) { + volatile xf_int_t *rdp = &xc->member_ctrl[m].reduce_done; + int member_done = xhc_atomic_load_int(rdp); + + /* Watch out for double evaluation here, don't perform + * sensitive loads inside opal_min()'s parameter list. */ + completed = opal_min(completed, member_done); + } + } + + if(xnc && completed > xnc->my_member_ctrl->reduce_ready) { + volatile xf_int_t *rrp = &xnc->my_member_ctrl->reduce_ready; + xhc_atomic_store_int(rrp, completed); + } else if(!xnc) { + size_t completed_bytes = completed * dtype_size; + + if(do_cico && completed_bytes > cico_copied) { + void *src = (char *) local_cico + cico_copied; + void *dst = (char *) rbuf + cico_copied; + + memcpy(dst, src, completed_bytes - cico_copied); + cico_copied = completed_bytes; + } + } + + if(completed >= count) { + xc->comm_ctrl->coll_ack = pvt_seq; + completed_comms++; + } + } + + // Is the reduction phase completed? + if(xc->my_member_ctrl->reduce_done < count) { + xhc_rq_item_t *member_item = NULL; + + int ret = xhc_allreduce_reduce_get_next(xc, + peer_info, count, dtype_size, do_cico, + out_of_order_reduce, pvt_seq, &member_item); + + if(ret != 0) { + return OMPI_ERROR; + } + + if(member_item) { + xhc_allreduce_do_reduce(xc, member_item, + count, datatype, dtype_size, op); + + xhc_allreduce_reduce_return_item(xc, member_item); + } + } + + if(!xc->is_coll_leader) { + /* If all reduction-related tasks are done, and + * not a leader on the next comm, can exit */ + if(xc->my_member_ctrl->reduce_done >= count + && xc->my_member_ctrl->reduce_ready >= count) { + goto _reduce_done; + } + + /* Not a leader in this comm, so not + * participating in higher-up ones. */ + break; + } + } + } + + _reduce_done: + + for(int i = 0; i < comm_count; i++) { + xhc_comm_t *xc = &comms[i]; + + /* Wait for the leader to give the signal that reduction + * has finished on this comm and members are free to exit */ + if(!xc->is_coll_leader) { + WAIT_FLAG(&xc->comm_ctrl->coll_ack, pvt_seq, OMPI_XHC_ACK_WIN); + } + + // load-store control dependency with coll_ack; no need for barrier + xc->my_member_ctrl->member_ack = pvt_seq; + + if(!xc->is_coll_leader) { + break; + } + } + + goto _finish; +} + +// ============================================================================= + +_finish: + + if(!do_cico) { + xhc_allreduce_disconnect_peers(comms, comm_count); + } + + return OMPI_SUCCESS; +} + +int mca_coll_xhc_allreduce(const void *sbuf, void *rbuf, + int count, ompi_datatype_t *datatype, ompi_op_t *op, + ompi_communicator_t *ompi_comm, mca_coll_base_module_t *ompi_module) { + + return xhc_allreduce_internal(sbuf, rbuf, + count, datatype, op, ompi_comm, ompi_module, true); +} diff --git a/ompi/mca/coll/xhc/coll_xhc_atomic.h b/ompi/mca/coll/xhc/coll_xhc_atomic.h new file mode 100644 index 00000000000..79f1dce98cb --- /dev/null +++ b/ompi/mca/coll/xhc/coll_xhc_atomic.h @@ -0,0 +1,116 @@ +/* + * Copyright (c) 2021-2023 Computer Architecture and VLSI Systems (CARV) + * Laboratory, ICS Forth. All rights reserved. 
+ * $COPYRIGHT$
+ *
+ * Additional copyrights may follow
+ *
+ * $HEADER$
+ */
+
+#ifndef MCA_COLL_XHC_ATOMIC_EXPORT_H
+#define MCA_COLL_XHC_ATOMIC_EXPORT_H
+
+#include <stdint.h>
+#include "opal/sys/atomic.h"
+
+// ----------------------------------------
+
+#define IS_SIG_ATOMIC_X_BITS(x) \
+    (SIG_ATOMIC_MAX == INT ## x ## _MAX) || (SIG_ATOMIC_MAX == UINT ## x ## _MAX)
+
+// ----------------------------------------
+
+// If xf_sig_t is ever re-defined to be signed,
+    // CHECK_FLAG()'s comparisons must be adjusted
+#if IS_SIG_ATOMIC_X_BITS(64)
+    typedef uint64_t xf_sig_t;
+#elif IS_SIG_ATOMIC_X_BITS(32)
+    typedef uint32_t xf_sig_t;
+#elif IS_SIG_ATOMIC_X_BITS(16)
+    typedef uint16_t xf_sig_t;
+#elif IS_SIG_ATOMIC_X_BITS(8)
+    typedef uint8_t xf_sig_t;
+#endif
+
+typedef int __attribute__((aligned(SIZEOF_INT))) xf_int_t;
+typedef size_t __attribute__((aligned(SIZEOF_SIZE_T))) xf_size_t;
+
+// ----------------------------------------
+
+#define xhc_atomic_rmb opal_atomic_rmb
+#define xhc_atomic_wmb opal_atomic_wmb
+#define xhc_atomic_fmb opal_atomic_mb
+
+// https://github.com/open-mpi/ompi/issues/9722
+
+#if OPAL_USE_GCC_BUILTIN_ATOMICS || OPAL_USE_C11_ATOMICS
+    #define xhc_atomic_load_int(addr) __atomic_load_n(addr, __ATOMIC_RELAXED)
+    #define xhc_atomic_store_int(addr, val) __atomic_store_n(addr, val, __ATOMIC_RELAXED)
+
+    #define xhc_atomic_load_size_t(addr) __atomic_load_n(addr, __ATOMIC_RELAXED)
+    #define xhc_atomic_store_size_t(addr, val) __atomic_store_n(addr, val, __ATOMIC_RELAXED)
+#else
+    #define xhc_atomic_load_int(addr) (*(addr))
+    #define xhc_atomic_store_int(addr, val) (*(addr) = (val))
+
+    #define xhc_atomic_load_size_t(addr) (*(addr))
+    #define xhc_atomic_store_size_t(addr, val) (*(addr) = (val))
+
+    #warning "GCC or the C11 atomics backend was not found.
XHC might not function correctly" +/* #else + #error "XHC atomics do not yet work without the GCC or the C11 backend" */ +#endif + + +// If/when opal atomic load/store size_t is added + +/* #define xhc_atomic_load_size_t(addr) \ + opal_atomic_load_size_t ((opal_atomic_size_t *) addr) +#define xhc_atomic_store_size_t(addr, val) \ + opal_atomic_store_size_t ((opal_atomic_size_t *) addr, val) */ + + +// If/when opal atomic load/store is added, and if opal atomic load/store int is not + +/* #if SIZEOF_INT == 4 + #define xhc_atomic_load_int(addr) opal_atomic_load_32 ((opal_atomic_int32_t *) addr) + #define xhc_atomic_store_int(addr, val) opal_atomic_store_32 ((opal_atomic_int32_t *) addr, val) +#elif SIZEOF_INT == 8 + #define xhc_atomic_load_int(addr) opal_atomic_load_64 ((opal_atomic_int64_t *) addr) + #define xhc_atomic_store_int(addr, val) opal_atomic_store_64 ((opal_atomic_int64_t *) addr, val) +#else + #error "Unsupported int size" +#endif */ + + +// If/when opal atomic load/store is added, and if opal atomic load/store size_t is not + +/* #if SIZEOF_SIZE_T == 4 + #define xhc_atomic_load_size_t(addr) opal_atomic_load_32 ((opal_atomic_int32_t *) addr) + #define xhc_atomic_store_size_t(addr, val) opal_atomic_store_32 ((opal_atomic_int32_t *) addr, val) +#elif SIZEOF_SIZE_T == 8 + #define xhc_atomic_load_size_t(addr) opal_atomic_load_64 ((opal_atomic_int64_t *) addr) + #define xhc_atomic_store_size_t(addr, val) opal_atomic_store_64 ((opal_atomic_int64_t *) addr, val) +#else + #error "Unsupported size_t size" +#endif */ + +static inline bool xhc_atomic_cmpxchg_strong_relaxed(volatile xf_sig_t *addr, + xf_sig_t *oldval, xf_sig_t newval) { + + #if OPAL_USE_GCC_BUILTIN_ATOMICS || OPAL_USE_C11_ATOMICS + return __atomic_compare_exchange_n(addr, oldval, newval, + false, __ATOMIC_RELAXED, __ATOMIC_RELAXED); + #else + #if IS_SIG_ATOMIC_X_BITS(32) + return opal_atomic_compare_exchange_strong_32(addr, oldval, newval); + #elif IS_SIG_ATOMIC_X_BITS(64) + return opal_atomic_compare_exchange_strong_64(addr, oldval, newval); + #else + #error "Unsupported sig_atomic_t size" + #endif + #endif +} + +#endif diff --git a/ompi/mca/coll/xhc/coll_xhc_barrier.c b/ompi/mca/coll/xhc/coll_xhc_barrier.c new file mode 100644 index 00000000000..ade1300134a --- /dev/null +++ b/ompi/mca/coll/xhc/coll_xhc_barrier.c @@ -0,0 +1,138 @@ +/* + * Copyright (c) 2021-2023 Computer Architecture and VLSI Systems (CARV) + * Laboratory, ICS Forth. All rights reserved. + * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +#include "ompi_config.h" +#include "mpi.h" + +#include "ompi/constants.h" +#include "ompi/communicator/communicator.h" + +#include "coll_xhc.h" + +static void xhc_barrier_leader(xhc_comm_t *comms, int comm_count, + xhc_peer_info_t *peer_info, int rank, int root, xf_sig_t seq) { + + // Non-leader by default + for(int i = 0; i < comm_count; i++) { + comms[i].is_coll_leader = false; + } + + for(int i = 0; i < comm_count; i++) { + // I'm the root and therefore always a leader + if(rank == root) { + comms[i].comm_ctrl->leader_seq = seq; + comms[i].is_coll_leader = true; + + continue; + } + + // The root takes leadership precedence when local + if(PEER_IS_LOCAL(peer_info, root, comms[i].locality)) { + break; + } + + // The member with the lowest ID (ie. 
the manager) becomes the leader + if(comms[i].member_id == 0) { + comms[i].comm_ctrl->leader_seq = seq; + comms[i].is_coll_leader = true; + } + + // Non-leaders exit; they can't become leaders on higher levels + if(comms[i].is_coll_leader == false) { + break; + } + } +} + +/* Hierarchical Barrier with seq/ack flags + * --------------------------------------- + * 1. Ranks write their coll_seq field to signal they have joined + * the collective. Leaders propagate this information towards + * the top-most comm's leader using the same method. + * + * 2. The top-most comm's leader (root) sets the comm's coll_ack + * field to signal, that all ranks have joined the barrier. + * + * 3. Leaders propagate the info towards the bottom-most comm, using + * the same method. Ranks wait on thei coll_ack flag, set their + * own ack, and exit the collective. + * --------------------------------------- */ +int mca_coll_xhc_barrier(ompi_communicator_t *ompi_comm, + mca_coll_base_module_t *ompi_module) { + + xhc_module_t *module = (xhc_module_t *) ompi_module; + + if(!module->init) { + int ret = xhc_lazy_init(module, ompi_comm); + if(ret != OMPI_SUCCESS) return ret; + } + + xhc_peer_info_t *peer_info = module->peer_info; + xhc_data_t *data = module->data; + + xhc_comm_t *comms = data->comms; + int comm_count = data->comm_count; + + int rank = ompi_comm_rank(ompi_comm); + + xf_sig_t pvt_seq = ++data->pvt_coll_seq; + + xhc_barrier_leader(comms, comm_count, peer_info, rank, + mca_coll_xhc_component.barrier_root, pvt_seq); + + // 1. Upwards SEQ Wave + for(int i = 0; i < comm_count; i++) { + xhc_comm_t *xc = &comms[i]; + + xc->my_member_ctrl->member_seq = pvt_seq; + + if(!xc->is_coll_leader) { + break; + } + + for(int m = 0; m < xc->size; m++) { + if(m == xc->member_id) { + continue; + } + + /* Poll comm members and wait for them to join the barrier. + * No need for windowed comparison here; Ranks won't exit the + * barrier before the leader has set the coll_seq flag. */ + WAIT_FLAG(&xc->member_ctrl[m].member_seq, pvt_seq, 0); + } + } + + // 2. Wait for ACK (root won't wait!) + for(int i = 0; i < comm_count; i++) { + xhc_comm_t *xc = &comms[i]; + + if(xc->is_coll_leader == false) { + WAIT_FLAG(&xc->comm_ctrl->coll_ack, pvt_seq, 0); + break; + } + } + + // 3. Trigger ACK Wave + for(int i = 0; i < comm_count; i++) { + xhc_comm_t *xc = &comms[i]; + + /* Not actually necessary for the barrier operation, but + * good for consistency between all seq/ack numbers */ + xc->my_member_ctrl->member_ack = pvt_seq; + + if(!xc->is_coll_leader) { + break; + } + + xc->comm_ctrl->coll_ack = pvt_seq; + } + + return OMPI_SUCCESS; +} diff --git a/ompi/mca/coll/xhc/coll_xhc_bcast.c b/ompi/mca/coll/xhc/coll_xhc_bcast.c new file mode 100644 index 00000000000..f0b99983e50 --- /dev/null +++ b/ompi/mca/coll/xhc/coll_xhc_bcast.c @@ -0,0 +1,341 @@ +/* + * Copyright (c) 2021-2023 Computer Architecture and VLSI Systems (CARV) + * Laboratory, ICS Forth. All rights reserved. 
+ * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +#include "ompi_config.h" +#include "mpi.h" + +#include "ompi/constants.h" +#include "ompi/datatype/ompi_datatype.h" +#include "ompi/communicator/communicator.h" +#include "opal/util/show_help.h" +#include "opal/util/minmax.h" + +#include "coll_xhc.h" + +/* When dynamic leadership is enabled, the first rank of each + * xhc comm to join the collective will become its leader */ +static void xhc_bcast_try_leader(xhc_comm_t *comms, int comm_count, + xhc_peer_info_t *peer_info, int rank, int root, xf_sig_t seq) { + + // Non-leader by default + for(int i = 0; i < comm_count; i++) { + comms[i].is_coll_leader = false; + } + + for(int i = 0; i < comm_count; i++) { + // I'm the root and therefore always a leader + if(rank == root) { + comms[i].comm_ctrl->leader_seq = seq; + comms[i].is_coll_leader = true; + + continue; + } + + // The root takes leadership precedence when local + if(PEER_IS_LOCAL(peer_info, root, comms[i].locality)) { + break; + } + + if(mca_coll_xhc_component.dynamic_leader == false) { + /* If dynamic leadership is disabled, the member with + * the lowest ID (ie. the manager) becomes the leader */ + if(comms[i].member_id == 0) { + comms[i].comm_ctrl->leader_seq = seq; + comms[i].is_coll_leader = true; + } + } else { + // An opportunity exists to become the leader + if(comms[i].comm_ctrl->leader_seq != seq) { + xf_sig_t oldval = seq - 1; + + comms[i].is_coll_leader = xhc_atomic_cmpxchg_strong_relaxed( + &comms[i].comm_ctrl->leader_seq, &oldval, seq); + } + } + + // Non-leaders exit; they can't become leaders on higher levels + if(comms[i].is_coll_leader == false) { + break; + } + } + + /* The writes and the cmpxchg to comm_ctrl->leader_seq, are relaxed. + * They do not synchronize access to any other data, and it's not a + * problem if some closeby loads/stores are reordered with it. The + * only purpose of leader_seq is to determine if a rank will be leader + * or not. Only the result of the cmp operation is utilized. */ +} + +static void xhc_bcast_children_init(xhc_comm_t *comms, int comm_count, + void *buffer, size_t bytes_ready, xhc_copy_data_t *region_data, + bool do_cico, int rank, xf_sig_t seq) { + + for(int i = comm_count - 1; i >= 0; i--) { + xhc_comm_t *xc = &comms[i]; + + if(!xc->is_coll_leader) { + continue; + } + + WAIT_FLAG(&xc->comm_ctrl->coll_ack, seq - 1, 0); + + /* Because there is a control dependency with the loads + * from coll_ack above and the code below, and because it + * is a load-store one (not load-load), I declare that a + * read-memory-barrier is not required here. */ + + xc->comm_ctrl->leader_id = xc->member_id; + xc->comm_ctrl->leader_rank = rank; + + xc->comm_ctrl->cico_id = (do_cico ? comms[0].manager_rank : -1); + + xc->comm_ctrl->data_vaddr = (!do_cico ? buffer : NULL); + xc->comm_ctrl->bytes_ready = bytes_ready; + + if(region_data != NULL) { + xhc_copy_region_post(xc->comm_ctrl->access_token, region_data); + } + + /* The above comm_ctrl stores must have finished before the + * peers are notified to attach/copy. We don't need an atomic + * store to bytes_ready here, since it is guarded by coll_seq. 
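+ *
+ * For reference, a sketch of the consumer side of this pairing, as it
+ * appears in mca_coll_xhc_bcast() further below: the writer fills the
+ * fields, issues xhc_atomic_wmb(), and only then publishes coll_seq; a
+ * reader waits on coll_seq, issues xhc_atomic_rmb(), and only then
+ * trusts the fields:
+ *
+ *   WAIT_FLAG(&src_ctrl->coll_seq, seq, 0);   // observe the new sequence
+ *   xhc_atomic_rmb();                         // order field loads after it
+ *   void *src = src_ctrl->data_vaddr;         // now safe to consume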
*/ + xhc_atomic_wmb(); + + xc->comm_ctrl->coll_seq = seq; + } +} + +static void xhc_bcast_children_set_bytes_ready(xhc_comm_t *comms, + int comm_count, size_t bytes) { + + for(int i = comm_count - 1; i >= 0; i--) { + xhc_comm_t *xc = &comms[i]; + + if(!xc->is_coll_leader) { + continue; + } + + volatile xf_size_t *brp = &xc->comm_ctrl->bytes_ready; + xhc_atomic_store_size_t(brp, bytes); + } + + /* Not much reason for a wmb() here or inside the loop. + * The stores may be reordered after any following stores, + * and within themselves. */ +} + +static void xhc_bcast_do_ack(xhc_comm_t *comms, + int comm_count, xf_sig_t seq) { + + // Set Ack(s) + for(int i = 0; i < comm_count; i++) { + xhc_comm_t *xc = &comms[i]; + + xc->my_member_ctrl->member_ack = seq; + + if(!xc->is_coll_leader) { + break; + } + } + + // Gather members' Ack(s) and set coll_ack + for(int i = 0; i < comm_count; i++) { + xhc_comm_t *xc = &comms[i]; + + if(!xc->is_coll_leader) { + break; + } + + for(int m = 0; m < xc->size; m++) { + if(m == xc->member_id) { + continue; + } + + WAIT_FLAG(&xc->member_ctrl[m].member_ack, seq, OMPI_XHC_ACK_WIN); + } + + xc->comm_ctrl->coll_ack = seq; + } +} + +static xhc_comm_t *xhc_bcast_src_comm(xhc_comm_t *comms, int comm_count) { + xhc_comm_t *s = NULL; + + for(int i = 0; i < comm_count; i++) { + if(!comms[i].is_coll_leader) { + s = &comms[i]; + break; + } + } + + return s; +} + +int mca_coll_xhc_bcast(void *buf, int count, ompi_datatype_t *datatype, int root, + ompi_communicator_t *ompi_comm, mca_coll_base_module_t *ompi_module) { + + xhc_module_t *module = (xhc_module_t *) ompi_module; + + if(!module->init) { + int ret = xhc_lazy_init(module, ompi_comm); + if(ret != OMPI_SUCCESS) return ret; + } + + if(!ompi_datatype_is_predefined(datatype)) { + static bool warn_shown = false; + + if(!warn_shown) { + opal_output_verbose(MCA_BASE_VERBOSE_WARN, + ompi_coll_base_framework.framework_output, + "coll:xhc: Warning: XHC does not currently support " + "derived datatypes; utilizing fallback component"); + warn_shown = true; + } + + xhc_coll_fns_t fallback = ((xhc_module_t *) module)->prev_colls; + return fallback.coll_bcast(buf, count, datatype, root, + ompi_comm, fallback.coll_bcast_module); + } + + // ---- + + xhc_peer_info_t *peer_info = module->peer_info; + xhc_data_t *data = module->data; + + xhc_comm_t *comms = data->comms; + int comm_count = data->comm_count; + + size_t dtype_size, bytes_total; + ompi_datatype_type_size(datatype, &dtype_size); + bytes_total = count * dtype_size; + + int rank = ompi_comm_rank(ompi_comm); + + bool do_cico = (bytes_total <= OMPI_XHC_CICO_MAX); + void *local_cico = xhc_get_cico(peer_info, comms[0].manager_rank); + void *src_buffer; + + // Only really necessary for smsc/knem + xhc_copy_data_t *region_data = NULL; + + // ---- + + xf_sig_t pvt_seq = ++data->pvt_coll_seq; + + xhc_bcast_try_leader(comms, comm_count, peer_info, rank, root, pvt_seq); + + // No chunking for now... TODO? + if(rank == root && do_cico) { + memcpy(local_cico, buf, bytes_total); + } + + if(!do_cico) { + int err = xhc_copy_expose_region(buf, bytes_total, ®ion_data); + if(err != 0) { + return OMPI_ERROR; + } + } + + xhc_bcast_children_init(comms, comm_count, buf, + (rank == root ? 
bytes_total : 0), region_data, do_cico, rank, pvt_seq); + + if(rank == root) { + goto coll_finish; + } + + // ---- + + /* Not actually necessary for the broadcast operation, but + * good for consistency between all seq/ack numbers */ + for(int i = 0; i < comm_count; i++) { + comms[i].my_member_ctrl->member_seq = pvt_seq; + if(!comms[i].is_coll_leader) { + break; + } + } + + xhc_comm_t *src_comm = xhc_bcast_src_comm(comms, comm_count); + xhc_comm_ctrl_t *src_ctrl = src_comm->comm_ctrl; + + WAIT_FLAG(&src_ctrl->coll_seq, pvt_seq, 0); + xhc_atomic_rmb(); + + if(!do_cico) { + src_buffer = src_ctrl->data_vaddr; + } else { + src_buffer = xhc_get_cico(peer_info, src_ctrl->cico_id); + if(src_buffer == NULL) return OMPI_ERR_OUT_OF_RESOURCE; + } + + size_t bytes_done = 0; + size_t bytes_available = 0; + + while(bytes_done < bytes_total) { + size_t copy_size = opal_min(src_comm->chunk_size, bytes_total - bytes_done); + + void *data_dst = (char *) buf + bytes_done; + void *data_src = (char *) src_buffer + bytes_done; + void *data_cico_dst = (char *) local_cico + bytes_done; + + if(bytes_available < copy_size) { + do { + volatile xf_size_t *brp = &src_ctrl->bytes_ready; + bytes_available = xhc_atomic_load_size_t(brp) - bytes_done; + } while(bytes_available < copy_size); + + // Wait on loads inside the loop + xhc_atomic_rmb(); + } + + /* Pipelining is not necessary on the bottom + * level, copy all available at once */ + if(!comms[0].is_coll_leader) { + copy_size = bytes_available; + } + + if(!do_cico) { + int err = xhc_copy_from(&peer_info[src_ctrl->leader_rank], + data_dst, data_src, copy_size, src_ctrl->access_token); + if(err != 0) { + return OMPI_ERROR; + } + } else { + memcpy((comms[0].is_coll_leader + ? data_cico_dst : data_dst), data_src, copy_size); + } + + bytes_done += copy_size; + bytes_available -= copy_size; + + /* Do make sure the memcpy has completed before + * writing to the peers' bytes_ready. */ + xhc_atomic_wmb(); + + xhc_bcast_children_set_bytes_ready(comms, comm_count, bytes_done); + + if(do_cico && comms[0].is_coll_leader) { + memcpy(data_dst, data_cico_dst, copy_size); + } + } + + if(!do_cico) { + xhc_copy_close_region(region_data); + } + + coll_finish: + + /* No wmb() necessary before sending ACK, as all operations + * that should be waited on (reads from shared buffers) have + * explicit barriers following them. */ + + xhc_bcast_do_ack(comms, comm_count, pvt_seq); + + return OMPI_SUCCESS; +} diff --git a/ompi/mca/coll/xhc/coll_xhc_component.c b/ompi/mca/coll/xhc/coll_xhc_component.c new file mode 100644 index 00000000000..dac4fd3db2d --- /dev/null +++ b/ompi/mca/coll/xhc/coll_xhc_component.c @@ -0,0 +1,677 @@ +/* + * Copyright (c) 2021-2023 Computer Architecture and VLSI Systems (CARV) + * Laboratory, ICS Forth. All rights reserved. 
+ * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +#include "ompi_config.h" +#include "mpi.h" + +#include "ompi/mca/coll/coll.h" +#include "ompi/mca/coll/base/base.h" + +#include "opal/mca/shmem/base/base.h" +#include "opal/util/show_help.h" + +#include "coll_xhc.h" + +typedef int (*csv_parse_conv_fn_t)(char *str, void *dst); +typedef void (*csv_parse_destruct_fn_t)(void *data); + +static int xhc_register(void); + +const char *mca_coll_xhc_component_version_string = + "Open MPI xhc collective MCA component version " OMPI_VERSION; + +static const char *hwloc_topo_str[] = { + "node", "flat", + "socket", + "numa", + "l3", "l3cache", + "l2", "l2cache", + "l1", "l1cache", + "core", + "hwthread", "thread" +}; + +static const xhc_loc_t hwloc_topo_val[] = { + OPAL_PROC_ON_NODE, OPAL_PROC_ON_NODE, + OPAL_PROC_ON_SOCKET, + OPAL_PROC_ON_NUMA, + OPAL_PROC_ON_L3CACHE, OPAL_PROC_ON_L3CACHE, + OPAL_PROC_ON_L2CACHE, OPAL_PROC_ON_L2CACHE, + OPAL_PROC_ON_L1CACHE, OPAL_PROC_ON_L1CACHE, + OPAL_PROC_ON_CORE, + OPAL_PROC_ON_HWTHREAD, OPAL_PROC_ON_HWTHREAD +}; + +mca_coll_xhc_component_t mca_coll_xhc_component = { + .super = { + .collm_version = { + MCA_COLL_BASE_VERSION_2_4_0, + + .mca_component_name = "xhc", + MCA_BASE_MAKE_VERSION(component, OMPI_MAJOR_VERSION, + OMPI_MINOR_VERSION, OMPI_RELEASE_VERSION), + + .mca_register_component_params = xhc_register, + }, + + .collm_data = { + MCA_BASE_METADATA_PARAM_CHECKPOINT + }, + + .collm_init_query = mca_coll_xhc_component_init_query, + .collm_comm_query = mca_coll_xhc_module_comm_query, + }, + + .priority = 0, + .print_info = false, + + .shmem_backing = NULL, + + .dynamic_leader = false, + + .barrier_root = 0, + + .dynamic_reduce = OMPI_XHC_DYNAMIC_REDUCE_NON_FLOAT, + .lb_reduce_leader_assist = + (OMPI_XHC_LB_RLA_TOP_LEVEL | OMPI_XHC_LB_RLA_FIRST_CHUNK), + + .force_reduce = false, + + .cico_max = 1024, + + .uniform_chunks = true, + .uniform_chunks_min = 1024, + + /* These are the parameters that will need + * processing, and their default values. */ + .hierarchy_mca = "numa,socket", + .chunk_size_mca = "16K" +}; + +/* Initial query function that is invoked during MPI_INIT, allowing + * this component to disqualify itself if it doesn't support the + * required level of thread support. 
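+ *
+ * (Aside, for illustration only: the variables registered later in this
+ * file surface as standard MCA parameters under the "coll_xhc_" prefix,
+ * so a hypothetical run could tune them with something like
+ * "mpirun --mca coll_xhc_priority 100 --mca coll_xhc_hierarchy numa,socket
+ * --mca coll_xhc_chunk_size 16K ..."; the values shown here are examples,
+ * not recommendations.)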
*/ +int mca_coll_xhc_component_init_query(bool enable_progress_threads, + bool enable_mpi_threads) { + + return OMPI_SUCCESS; +} + +static mca_base_var_enum_value_t dynamic_reduce_options[] = { + {OMPI_XHC_DYNAMIC_REDUCE_DISABLED, "disabled"}, + {OMPI_XHC_DYNAMIC_REDUCE_NON_FLOAT, "non-float"}, + {OMPI_XHC_DYNAMIC_REDUCE_ALL, "all"}, + {0, NULL} +}; + +static mca_base_var_enum_value_flag_t lb_reduce_leader_assist_options[] = { + {OMPI_XHC_LB_RLA_TOP_LEVEL, "top", OMPI_XHC_LB_RLA_ALL}, + {OMPI_XHC_LB_RLA_FIRST_CHUNK, "first", OMPI_XHC_LB_RLA_ALL}, + {OMPI_XHC_LB_RLA_ALL, "all", + (OMPI_XHC_LB_RLA_TOP_LEVEL | OMPI_XHC_LB_RLA_FIRST_CHUNK)}, + {0, NULL, 0} +}; + +static int xhc_register(void) { + mca_base_var_enum_t *var_enum; + mca_base_var_enum_flag_t *var_enum_flag; + char *tmp, *desc; + int ret; + + /* Priority */ + + (void) mca_base_component_var_register(&mca_coll_xhc_component.super.collm_version, + "priority", "Priority of the xhc component", + MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, OPAL_INFO_LVL_2, + MCA_BASE_VAR_SCOPE_READONLY, &mca_coll_xhc_component.priority); + + /* Info */ + + (void) mca_base_component_var_register(&mca_coll_xhc_component.super.collm_version, + "print_info", "Print information during initialization", + MCA_BASE_VAR_TYPE_BOOL, NULL, 0, 0, OPAL_INFO_LVL_3, + MCA_BASE_VAR_SCOPE_READONLY, &mca_coll_xhc_component.print_info); + + /* SHM Backing dir */ + + mca_coll_xhc_component.shmem_backing = (access("/dev/shm", W_OK) == 0 ? + "/dev/shm" : opal_process_info.job_session_dir); + + (void) mca_base_component_var_register(&mca_coll_xhc_component.super.collm_version, + "shmem_backing", "Directory to place backing files for shared-memory" + " control-data communication", MCA_BASE_VAR_TYPE_STRING, NULL, 0, 0, + OPAL_INFO_LVL_3, MCA_BASE_VAR_SCOPE_READONLY, + &mca_coll_xhc_component.shmem_backing); + + /* Dynamic leader */ + + (void) mca_base_component_var_register(&mca_coll_xhc_component.super.collm_version, + "dynamic_leader", "Enable dynamic operation-wise group-leader selection", + MCA_BASE_VAR_TYPE_BOOL, NULL, 0, 0, OPAL_INFO_LVL_5, + MCA_BASE_VAR_SCOPE_READONLY, &mca_coll_xhc_component.dynamic_leader); + + /* Dynamic reduce */ + + ret = mca_base_var_enum_create("coll_xhc_dynamic_reduce_options", + dynamic_reduce_options, &var_enum); + if(ret != OPAL_SUCCESS) { + return ret; + } + + /* Barrier root */ + + (void) mca_base_component_var_register(&mca_coll_xhc_component.super.collm_version, + "barrier_root", "Internal root for the barrier operation (rank ID)", + MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, OPAL_INFO_LVL_5, + MCA_BASE_VAR_SCOPE_READONLY, &mca_coll_xhc_component.barrier_root); + + (void) mca_base_component_var_register(&mca_coll_xhc_component.super.collm_version, + "dynamic_reduce", "Dynamic/out-of-order intra-group reduction", + MCA_BASE_VAR_TYPE_INT, var_enum, 0, 0, OPAL_INFO_LVL_6, + MCA_BASE_VAR_SCOPE_READONLY, &mca_coll_xhc_component.dynamic_reduce); + + OBJ_RELEASE(var_enum); + + /* Load balancing: Reduce leader assistance */ + + ret = mca_base_var_enum_create_flag("coll_xhc_lb_reduce_leader_assist", + lb_reduce_leader_assist_options, &var_enum_flag); + if(ret != OPAL_SUCCESS) { + return ret; + } + + (void) mca_base_component_var_register(&mca_coll_xhc_component.super.collm_version, + "lb_reduce_leader_assist", "Reduction leader assistance modes for load balancing", + MCA_BASE_VAR_TYPE_INT, &var_enum_flag->super, 0, 0, OPAL_INFO_LVL_6, + MCA_BASE_VAR_SCOPE_READONLY, &mca_coll_xhc_component.lb_reduce_leader_assist); + + OBJ_RELEASE(var_enum_flag); + + /* Force enable 
"hacky" reduce */ + + (void) mca_base_component_var_register(&mca_coll_xhc_component.super.collm_version, + "force_reduce", "Force enable the \"special\" Reduce for all calls", + MCA_BASE_VAR_TYPE_BOOL, NULL, 0, 0, OPAL_INFO_LVL_9, + MCA_BASE_VAR_SCOPE_READONLY, &mca_coll_xhc_component.force_reduce); + + /* Hierarchy features */ + + desc = NULL; + + for(size_t i = 0; i < sizeof(hwloc_topo_str)/sizeof(char *); i++) { + ret = opal_asprintf(&tmp, "%s%s%s", (i > 0 ? desc : ""), + (i > 0 ? ", " : ""), hwloc_topo_str[i]); + free(desc); desc = tmp; + if(ret < 0) { + return OPAL_ERR_OUT_OF_RESOURCE; + } + } + + ret = opal_asprintf(&tmp, "Comma-separated list of topology features to " + "consider for the hierarchy (%s)", desc); + free(desc); desc = tmp; + if(ret < 0) { + return OMPI_ERR_OUT_OF_RESOURCE; + } + + (void) mca_base_component_var_register(&mca_coll_xhc_component.super.collm_version, + "hierarchy", desc, MCA_BASE_VAR_TYPE_STRING, NULL, 0, 0, OPAL_INFO_LVL_4, + MCA_BASE_VAR_SCOPE_READONLY, &mca_coll_xhc_component.hierarchy_mca); + + free(desc); + + /* Chunk size(s) */ + + (void) mca_base_component_var_register(&mca_coll_xhc_component.super.collm_version, + "chunk_size", "The chunk size(s) to be used for the pipeline " + "(single value, or comma separated list for different hierarchy levels " + "(bottom to top))", + MCA_BASE_VAR_TYPE_STRING, NULL, 0, 0, OPAL_INFO_LVL_5, + MCA_BASE_VAR_SCOPE_READONLY, &mca_coll_xhc_component.chunk_size_mca); + + /* Allreduce uniform chunks */ + + (void) mca_base_component_var_register(&mca_coll_xhc_component.super.collm_version, + "uniform_chunks", "Automatically optimize chunk size in reduction " + "collectives according to message size, for load balancing", + MCA_BASE_VAR_TYPE_BOOL, NULL, 0, 0, OPAL_INFO_LVL_5, + MCA_BASE_VAR_SCOPE_READONLY, &mca_coll_xhc_component.uniform_chunks); + + /* Allreduce uniform chunks min size */ + + (void) mca_base_component_var_register(&mca_coll_xhc_component.super.collm_version, + "uniform_chunks_min", "Minimum chunk size for reduction collectives, " + "when \"uniform chunks\" are enabled", MCA_BASE_VAR_TYPE_SIZE_T, + NULL, 0, 0, OPAL_INFO_LVL_5, MCA_BASE_VAR_SCOPE_READONLY, + &mca_coll_xhc_component.uniform_chunks_min); + + /* CICO threshold (inclusive) */ + + (void) mca_base_component_var_register(&mca_coll_xhc_component.super.collm_version, + "cico_max", "Maximum message size up to which to use CICO", + MCA_BASE_VAR_TYPE_SIZE_T, NULL, 0, 0, OPAL_INFO_LVL_5, + MCA_BASE_VAR_SCOPE_READONLY, &mca_coll_xhc_component.cico_max); + + return OMPI_SUCCESS; +} + +static int parse_csv(const char *csv_orig, char sep, char ignore_start, + char ignore_end, void **vals_dst, int *len_dst, size_t type_size, + csv_parse_conv_fn_t conv_fn, csv_parse_destruct_fn_t destructor_fn, + char *err_help_header) { + + if(csv_orig == NULL || strlen(csv_orig) == 0) { + *vals_dst = NULL; + *len_dst = 0; + return OMPI_SUCCESS; + } + + char *csv = NULL; + void *vals = NULL; + + int vals_size = 0; + int ntokens = 0; + + int return_code = OMPI_SUCCESS; + + if(!(csv = strdup(csv_orig))) { + RETURN_WITH_ERROR(return_code, OMPI_ERR_OUT_OF_RESOURCE, end); + } + + if(!(vals = malloc((vals_size = 5) * type_size))) { + RETURN_WITH_ERROR(return_code, OMPI_ERR_OUT_OF_RESOURCE, end); + } + + int ignore_cnt = 0; + char *token = csv; + + int csv_len = strlen(csv); + + for(int i = 0; i < csv_len + 1; i++) { + char *c = csv+i; + + if(ntokens == vals_size) { + void *tmp = realloc(vals, (vals_size *= 2) * sizeof(type_size)); + if(!tmp) { + RETURN_WITH_ERROR(return_code, 
OMPI_ERR_OUT_OF_RESOURCE, end); + } + vals = tmp; + } + + if(ignore_start != 0) { + if(*c == ignore_start) { + ignore_cnt++; + } else if(*c == ignore_end) { + ignore_cnt--; + } + + if(ignore_cnt < 0) { + RETURN_WITH_ERROR(return_code, OMPI_ERR_BAD_PARAM, end); + } + } + + if(ignore_cnt == 0 && (*c == sep || *c == '\0')) { + char oldc = *c; + *c = '\0'; + + int status = conv_fn(token, (char *) vals + ntokens*type_size); + + if(status != OMPI_SUCCESS) { + if(err_help_header) { + opal_show_help("help-coll-xhc.txt", + err_help_header, true, token, csv_orig); + } + + RETURN_WITH_ERROR(return_code, status, end); + } + + ntokens++; + + *c = oldc; + token = c + 1; + } + } + + *vals_dst = vals; + *len_dst = ntokens; + + end: + + free(csv); + + if(return_code != OMPI_SUCCESS) { + if(vals && destructor_fn) { + for(int i = 0; i < ntokens; i++) { + destructor_fn((char *) vals + i*type_size); + } + } + + free(vals); + } + + return return_code; +} + +static int conv_xhc_loc_def_rank_list(char *str, void *result) { + char *strs[2] = {str, NULL}; + int nums[2] = {-1, -1}; + + char *range_op_pos = NULL; + + int return_code = OMPI_SUCCESS; + + if((range_op_pos = strstr(str, ".."))) { + strs[1] = range_op_pos + 2; + *range_op_pos = '\0'; + } + + for(int i = 0; i < 2 && strs[i]; i++) { + char *endptr; + + nums[i] = strtol(strs[i], &endptr, 10); + + if(endptr[0] != '\0' || nums[i] < 0) { + RETURN_WITH_ERROR(return_code, OMPI_ERR_BAD_PARAM, end); + } + } + + ((xhc_rank_range_t *) result)->start_rank = nums[0]; + ((xhc_rank_range_t *) result)->end_rank = (nums[1] != -1 ? nums[1] : nums[0]); + + end: + + if(range_op_pos) { + *range_op_pos = '.'; + } + + return return_code; +} + +static void mca_coll_xhc_loc_def_construct(xhc_loc_def_t *def) { + def->named_loc = 0; + def->rank_list = NULL; + def->rank_list_len = 0; + def->split = 0; + def->max_ranks = 0; + def->repeat = false; +} + +static void mca_coll_xhc_loc_def_destruct(xhc_loc_def_t *def) { + free(def->rank_list); +} + +OBJ_CLASS_INSTANCE(xhc_loc_def_t, opal_list_item_t, + mca_coll_xhc_loc_def_construct, mca_coll_xhc_loc_def_destruct); + +static int conv_xhc_loc_def(char *str, void *result) { + int return_code = OMPI_SUCCESS; + + char *s = strdup(str); + xhc_loc_def_t *def = OBJ_NEW(xhc_loc_def_t); + + if(!s || !def) { + RETURN_WITH_ERROR(return_code, OMPI_ERR_OUT_OF_RESOURCE, end); + } + + /* Parse modifiers and remove them from string */ + + if(s[strlen(s) - 1] == '*') { + def->repeat = true; + s[strlen(s) - 1] = '\0'; + } + + char *colon_pos = strrchr(s, ':'); + char *qmark_pos = strrchr(s, '?'); + + if(colon_pos && qmark_pos) { + RETURN_WITH_ERROR(return_code, OMPI_ERR_BAD_PARAM, end); + } else if(colon_pos || qmark_pos) { + char *numstr = (colon_pos ? 
colon_pos : qmark_pos); + char *endptr; + + int num = strtol(numstr + 1, &endptr, 10); + + if(endptr[0] != '\0' || num <= 0) { + RETURN_WITH_ERROR(return_code, OMPI_ERR_BAD_PARAM, end); + } + + if(colon_pos) { + def->split = num; + } else { + def->max_ranks = num; + } + + *numstr = '\0'; + } + + /* Parse locality definition */ + + if(s[0] == '[') { + if(def->repeat) { // repeat only makes sense with named localities + RETURN_WITH_ERROR(return_code, OMPI_ERR_BAD_PARAM, end); + } + + s[strlen(s) - 1] = '\0'; + + int status = parse_csv(s+1, ',', 0, 0, (void **) &def->rank_list, + &def->rank_list_len, sizeof(xhc_rank_range_t), + conv_xhc_loc_def_rank_list, NULL, NULL); + + if(status != OMPI_SUCCESS) { + RETURN_WITH_ERROR(return_code, status, end); + } + } else { + bool found = false; + + for(size_t i = 0; i < sizeof(hwloc_topo_str)/sizeof(char *); i++) { + if(strcasecmp(s, hwloc_topo_str[i]) == 0) { + def->named_loc = hwloc_topo_val[i]; + found = true; + break; + } + } + + if(!found) { + RETURN_WITH_ERROR(return_code, OMPI_ERR_BAD_PARAM, end); + } + } + + * (xhc_loc_def_t **) result = def; + + end: + + free(s); + + if(return_code != OMPI_SUCCESS) { + OBJ_RELEASE_IF_NOT_NULL(def); + } + + return return_code; +} + +static void destruct_xhc_loc_def(void *data) { + OBJ_RELEASE(* (xhc_loc_def_t **) data); +} + +static int conv_xhc_loc_def_combination(char *str, void *result) { + xhc_loc_def_t **defs; + int ndefs; + + int status = parse_csv(str, '+', 0, 0, (void **) &defs, + &ndefs, sizeof(xhc_loc_def_t *), conv_xhc_loc_def, + destruct_xhc_loc_def, NULL); + if(status != OMPI_SUCCESS) { + return status; + } + + opal_list_t *def_list = (opal_list_t *) result; + OBJ_CONSTRUCT(def_list, opal_list_t); + + for(int i = 0; i < ndefs; i++) { + opal_list_append(def_list, (opal_list_item_t *) defs[i]); + } + + free(defs); + + return OMPI_SUCCESS; +} + +static void destruct_xhc_loc_def_combination(void *data) { + OPAL_LIST_DESTRUCT((opal_list_t *) data); +} + +int mca_coll_xhc_component_parse_hierarchy(const char *val_str, + opal_list_t **level_defs_dst, int *nlevel_defs_dst) { + + /* The hierarchy is in a comma-separated list format. Each item in the + * list specifies how to group ranks, and each different item entails + * a grouping step. + * + * Each item in this list is a '+'-separated list. Of course, this can + * be just one item, without any delimiter, specifying the locality to + * follow for the grouping (e.g. numa, socket, etc). + * + * But, it can also be more complex (multiple '+'-separated items), used + * to describe virtual hierarchies. This allows to group different ranks + * in different ways, e.g. some ranks according to numa, then others by + * something else, etc. + * + * Each item in this '+'-separated list, can be of the following types: + * 1. A "named locality", e.g. hwloc's localities (only ones currently + * available), see hwloc_topo_str[]. + * 2. A list of ranks that should be grouped together. This is a comma- + * separated list of integers, enclosed in [] (I know, list-ception!). + * It may also contain range operators (..), to select multiple ranks + * at once (e.g. 0..3 expands to 0,1,2,3). Example: [0..15,20,22]. + * The order of the ranks does not matter. + * + * Finally, each such item may be suffixed by a special modifier: + * 1. The split modifier (:) specifies to group according to the + * locality it refers to, but to split each such group into multiple + * parts. E.g. 
the locality 'numa:2' will group ranks according to NUMA, but split + * each NUMA node's ranks into two groups, such that half the ranks are in one + * group, and the rest are in another. + * 2. The max-ranks modifier (?) works similarly to the split modifier, + * only that it specifies that at most _n_ ranks should be placed in + * each group. If more than _n_ ranks share the locality the modifier + * refers to, multiple groups will be created for these ranks, each one + * not more than _n_ ranks in size. + * 3. The repeat modifier (*), which can be specified along with the two + * previous modifiers, allows manual control over the repetition of + * named localities. See below, under 'repetition'. + * + * Repetition: + * Named localities are repeated for all distinct rank clusters. For + * example, "numa", even though it is a single key, means to group + * all ranks that are in the same NUMA together, which will lead to + * multiple groups if multiple NUMA nodes are present. This is in + * contrast to rank lists, which only create a single group, containing + * the ranks specified in it. The different items in the '+'-separated + * list are consumed in-order left-to-right, and any named localities + * are automatically repeated to apply to all ranks that are not included + * in other items. When multiple named localities are present one after + * the other, the last one is repeated, unless another repetition was + * explicitly requested via the repeat modifier. + * + * Examples: + * "numa": Group according to numa locality + * "numa,socket": Group according to numa and then socket locality + * "node"/"flat": Group according to node locality -> all ranks in + * same node -> flat hierarchy, i.e. none at all + * + * "numa:2,socket": Group according to numa locality but with two + * groups per NUMA, and then according to socket. + * "numa:2,numa,socket": Similar to the previous one, but in this case + * one of the two half-NUMA leaders will further become + * the leader of the NUMA node. + * "numa?10,socket": Group according to numa, but no more than 10 ranks + * per NUMA; create multiple groups if necessary. Then group according + * to socket. + * + * "[0..9]+[10..24]": Create two groups: one for the first 10 ranks, + * and another for the next 15. + * "[0..39]+numa,socket": Group the first 40 ranks, and the rest + * according to numa locality. Then group according to socket. + * + * "socket+socket:2": Create at least two groups: one for all ranks + * in the first socket, and then group all the other ranks according + * to socket locality, but with two groups for each socket. + * "socket*+socket:2": Similar to the previous one, but only the last + * socket is split into two groups; all the other ranks are grouped + * according to socket locality. + * + * If the top-most locality specified does not cover all ranks, one such + * locality will automatically be added (in the hierarchy sort method). + * + * (Oh god what have I done! 
-Frankenstein, probably) */ + + int status = parse_csv(val_str, ',', '[', ']', (void **) level_defs_dst, + nlevel_defs_dst, sizeof(opal_list_t), conv_xhc_loc_def_combination, + destruct_xhc_loc_def_combination, "bad-hierarchy-item"); + + return status; +} + +static int conv_chunk_size(char *str, void *result) { + size_t last_idx = strlen(str) - 1; + char saved_char = str[last_idx]; + + size_t mult = 1; + + switch(str[last_idx]) { + case 'g': case 'G': + mult *= 1024; + case 'm': case 'M': + mult *= 1024; + case 'k': case 'K': + mult *= 1024; + + str[last_idx] = '\0'; + } + + bool legal = (str[0] != '\0'); + + for(char *c = str; *c; c++) { + if((*c < '0' || *c > '9') && *c != '-') { + legal = false; + break; + } + } + + if(legal) { + long long num = atoll(str) * mult; + * (size_t *) result = (size_t) (num > 0 ? num : -1); + } + + str[last_idx] = saved_char; + + return (legal ? OMPI_SUCCESS : OMPI_ERR_BAD_PARAM); +} + +int mca_coll_xhc_component_parse_chunk_sizes(const char *val_str, + size_t **chunks_dst, int *len_dst) { + + if(val_str == NULL) { + *chunks_dst = malloc(sizeof(size_t)); + if(*chunks_dst == NULL) { + return OMPI_ERR_OUT_OF_RESOURCE; + } + + (*chunks_dst)[0] = (size_t) -1; + *len_dst = 1; + + return OMPI_SUCCESS; + } + + int status = parse_csv(val_str, ',', 0, 0, (void **) chunks_dst, len_dst, + sizeof(size_t), conv_chunk_size, NULL, "bad-chunk-size-item"); + + return status; +} diff --git a/ompi/mca/coll/xhc/coll_xhc_module.c b/ompi/mca/coll/xhc/coll_xhc_module.c new file mode 100644 index 00000000000..879e521f662 --- /dev/null +++ b/ompi/mca/coll/xhc/coll_xhc_module.c @@ -0,0 +1,721 @@ +/* + * Copyright (c) 2021-2023 Computer Architecture and VLSI Systems (CARV) + * Laboratory, ICS Forth. All rights reserved. + * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +#include "ompi_config.h" + +#include +#include + +#include "mpi.h" + +#include "ompi/mca/coll/coll.h" +#include "ompi/mca/coll/base/base.h" +#include "opal/mca/smsc/smsc.h" + +#include "opal/util/arch.h" +#include "opal/util/show_help.h" +#include "opal/util/minmax.h" + +#include "coll_xhc.h" + +static int xhc_module_save_fallback_fns( + xhc_module_t *module, ompi_communicator_t *comm); + +static int xhc_module_create_hierarchy(mca_coll_xhc_module_t *module, + ompi_communicator_t *comm, opal_list_t *level_defs, int nlevel_defs, + xhc_loc_t **hierarchy_dst, int *hierarchy_len_dst); + +static int xhc_module_sort_hierarchy(mca_coll_xhc_module_t *module, + ompi_communicator_t *comm, xhc_loc_t **hierarchy_dst, int *hierarchy_len_dst); + +// ----------------------------- + +static void xhc_module_clear(xhc_module_t *module) { + memset(&module->prev_colls, 0, sizeof(module->prev_colls)); + + module->comm_size = 0; + module->rank = -1; + + module->hierarchy_string = NULL; + module->hierarchy = NULL; + module->hierarchy_len = 0; + + module->chunks = NULL; + module->chunks_len = 0; + + module->rbuf = NULL; + module->rbuf_size = 0; + + module->peer_info = NULL; + module->data = NULL; + module->init = false; +} + +static void mca_coll_xhc_module_construct(mca_coll_xhc_module_t *module) { + xhc_module_clear(module); +} + +static void mca_coll_xhc_module_destruct(mca_coll_xhc_module_t *module) { + xhc_fini(module); + + free(module->hierarchy_string); + free(module->hierarchy); + free(module->chunks); + free(module->rbuf); + free(module->peer_info); + + xhc_module_clear(module); +} + +OBJ_CLASS_INSTANCE(mca_coll_xhc_module_t, mca_coll_base_module_t, + mca_coll_xhc_module_construct, 
mca_coll_xhc_module_destruct); + +// ----------------------------- + +mca_coll_base_module_t *mca_coll_xhc_module_comm_query(ompi_communicator_t *comm, + int *priority) { + + if((*priority = mca_coll_xhc_component.priority) < 0) { + return NULL; + } + + if(OMPI_COMM_IS_INTER(comm) || ompi_comm_size(comm) == 1 + || ompi_group_have_remote_peers (comm->c_local_group)) { + + opal_output_verbose(MCA_BASE_VERBOSE_COMPONENT, + ompi_coll_base_framework.framework_output, + "coll:xhc:comm_query (%s/%s): intercomm, self-comm, " + "or not all ranks local; disqualifying myself", + ompi_comm_print_cid(comm), comm->c_name); + + return NULL; + } + + int comm_size = ompi_comm_size(comm); + for(int r = 0; r < comm_size; r++) { + ompi_proc_t *proc = ompi_comm_peer_lookup(comm, r); + + if(proc->super.proc_arch != opal_local_arch) { + opal_output_verbose(MCA_BASE_VERBOSE_COMPONENT, + ompi_coll_base_framework.framework_output, + "coll:xhc:comm_query (%s/%s): All ranks not of the same arch; " + "disabling myself", ompi_comm_print_cid(comm), comm->c_name); + + return NULL; + } + } + + mca_coll_base_module_t *module = + (mca_coll_base_module_t *) OBJ_NEW(mca_coll_xhc_module_t); + + if(module == NULL) { + return NULL; + } + + module->coll_module_enable = mca_coll_xhc_module_enable; + module->coll_module_disable = mca_coll_xhc_module_disable; + + module->coll_barrier = mca_coll_xhc_barrier; + + if(mca_smsc == NULL) { + opal_output_verbose(MCA_BASE_VERBOSE_COMPONENT, + ompi_coll_base_framework.framework_output, + "coll:xhc: Warning: No opal/smsc support found; " + "only barrier will be enabled"); + + return module; + } + + module->coll_bcast = mca_coll_xhc_bcast; + + if(!mca_smsc_base_has_feature(MCA_SMSC_FEATURE_CAN_MAP)) { + opal_output_verbose(MCA_BASE_VERBOSE_COMPONENT, + ompi_coll_base_framework.framework_output, + "coll:xhc: Warning: opal/smsc module is not CAN_MAP capable; " + "(all)reduce will be disabled, bcast might see reduced performance"); + + return module; + } + + module->coll_allreduce = mca_coll_xhc_allreduce; + module->coll_reduce = mca_coll_xhc_reduce; + + return module; +} + +#define COLL_FN_HELPER(_m, _api) .coll_ ## _api = (_m)->coll_ ## _api, \ + .coll_ ## _api ## _module = (_m) + +int mca_coll_xhc_module_enable(mca_coll_base_module_t *ompi_module, + ompi_communicator_t *comm) { + + xhc_module_t *module = (xhc_module_t *) ompi_module; + + int ret; + + // --- + + ret = xhc_module_save_fallback_fns(module, comm); + + /* This can/will happen often (see #9885), but theoretically + * isn't a problem, as in these cases the component wouldn't + * end up getting used anyway. 
*/ + if(ret != OMPI_SUCCESS) { + opal_output_verbose(MCA_BASE_VERBOSE_COMPONENT, + ompi_coll_base_framework.framework_output, + "coll:xhc:module_enable (%s/%s): No previous fallback component " + "found; disabling myself", ompi_comm_print_cid(comm), comm->c_name); + + return ret; + } + + // --- + + module->comm_size = ompi_comm_size(comm); + module->rank = ompi_comm_rank(comm); + + module->peer_info = calloc(module->comm_size, sizeof(xhc_peer_info_t)); + + for(int r = 0; r < module->comm_size; r++) { + ompi_proc_t *peer_proc = ompi_comm_peer_lookup(comm, r); + + module->peer_info[r].proc = peer_proc; + module->peer_info[r].locality = peer_proc->super.proc_flags; + } + + module->peer_info[module->rank].locality |= + ((1 << OMPI_XHC_LOC_EXT_BITS) - 1) << OMPI_XHC_LOC_EXT_START; + + // --- + + /* This needs to happen here, and we need to save the hierarchy string, + * because the info value will be gone by the time lazy_init is + * called. Furthermore, we can't prepare the hierarchy here, as it might + * require communication (allgather) with the other ranks. */ + + const char *hier_mca = mca_coll_xhc_component.hierarchy_mca; + + opal_cstring_t *hier_info; + int hier_info_flag = 0; + + if(comm->super.s_info != NULL) { + opal_info_get(comm->super.s_info, "ompi_comm_coll_xhc_hierarchy", + &hier_info, &hier_info_flag); + + if(hier_info_flag) { + hier_mca = hier_info->string; + } + } + + module->hierarchy_string = strdup(hier_mca); + + if(hier_info_flag) { + OBJ_RELEASE(hier_info); + } + + if(!module->hierarchy_string) { + return OMPI_ERR_OUT_OF_RESOURCE; + } + + // --- + + ret = xhc_component_parse_chunk_sizes(mca_coll_xhc_component.chunk_size_mca, + &module->chunks, &module->chunks_len); + if(ret != OMPI_SUCCESS) { + return ret; + } + + // --- + + xhc_coll_fns_t xhc_fns = (xhc_coll_fns_t) { + COLL_FN_HELPER(ompi_module, allreduce), + COLL_FN_HELPER(ompi_module, barrier), + COLL_FN_HELPER(ompi_module, bcast), + COLL_FN_HELPER(ompi_module, reduce) + }; + + xhc_module_install_fns(module, comm, xhc_fns); + + return OMPI_SUCCESS; +} + +int mca_coll_xhc_module_disable(mca_coll_base_module_t *ompi_module, + ompi_communicator_t *comm) { + + xhc_module_t *module = (xhc_module_t *) ompi_module; + + xhc_module_install_fallback_fns(module, comm, NULL); + mca_coll_xhc_module_destruct(module); + + return OMPI_SUCCESS; +} + +// ----------------------------- + +#define SAVE_FALLBACK_COLL(_comm, _m, _dst, _api) do { \ + if((_m)->coll_ ## _api) { \ + MCA_COLL_SAVE_API(_comm, _api, (_dst).coll_ ## _api, \ + (_dst).coll_ ## _api ## _module, "xhc"); \ + \ + if(!(_dst).coll_ ## _api || !(_dst).coll_ ## _api ## _module) { \ + _save_status = OMPI_ERR_NOT_FOUND; \ + } \ + } \ +} while(0) + +#define INSTALL_FALLBACK_COLL(_comm, _m, _saved, _new, _api) do { \ + if((_comm)->c_coll->coll_ ## _api ## _module == (_m)) { \ + MCA_COLL_SAVE_API(_comm, _api, (_saved).coll_ ## _api, \ + (_saved).coll_ ## _api ## _module, "xhc"); \ + MCA_COLL_INSTALL_API(_comm, _api, (_new).coll_ ## _api, \ + (_new).coll_ ## _api ## _module, "xhc"); \ + } \ +} while(0) + +#define INSTALL_COLL(_comm, _src, _api) do { \ + if((_src).coll_ ## _api) { \ + MCA_COLL_INSTALL_API(_comm, _api, (_src).coll_ ## _api, \ + (_src).coll_ ## _api ## _module, "xhc"); \ + } \ +} while(0) + +/* Save the function pointers of the previous module, in XHC's + * struct. Only the functions that XHC will provide are saved. 
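+ *
+ * For illustration, a hand-expanded sketch of the bcast case,
+ * SAVE_FALLBACK_COLL(comm, ompi_module, colls, bcast), behaves roughly like
+ * (ignoring the do/while(0) wrapper):
+ *
+ *   if(ompi_module->coll_bcast) {
+ *       MCA_COLL_SAVE_API(comm, bcast, colls.coll_bcast,
+ *           colls.coll_bcast_module, "xhc");
+ *       if(!colls.coll_bcast || !colls.coll_bcast_module) {
+ *           _save_status = OMPI_ERR_NOT_FOUND;
+ *       }
+ *   }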
*/ +static int xhc_module_save_fallback_fns( + xhc_module_t *module, ompi_communicator_t *comm) { + + mca_coll_base_module_t *ompi_module = (mca_coll_base_module_t *) module; + + xhc_coll_fns_t colls = {0}; + int _save_status = OMPI_SUCCESS; + + SAVE_FALLBACK_COLL(comm, ompi_module, colls, allreduce); + SAVE_FALLBACK_COLL(comm, ompi_module, colls, barrier); + SAVE_FALLBACK_COLL(comm, ompi_module, colls, bcast); + SAVE_FALLBACK_COLL(comm, ompi_module, colls, reduce); + + if(_save_status == OMPI_SUCCESS) { + module->prev_colls = colls; + } + + return _save_status; +} + +/* Replace XHC's pointers in c_coll with those from the fallback + * component saved earlier. XHC's pointers are conveniently returned + * in prev_fns_dst, to later pass to xhc_module_install_fns. */ +void mca_coll_xhc_module_install_fallback_fns(xhc_module_t *module, + ompi_communicator_t *comm, xhc_coll_fns_t *prev_fns_dst) { + + mca_coll_base_module_t *ompi_module = (mca_coll_base_module_t *) module; + + xhc_coll_fns_t saved = {0}; + + INSTALL_FALLBACK_COLL(comm, ompi_module, saved, module->prev_colls, allreduce); + INSTALL_FALLBACK_COLL(comm, ompi_module, saved, module->prev_colls, barrier); + INSTALL_FALLBACK_COLL(comm, ompi_module, saved, module->prev_colls, bcast); + INSTALL_FALLBACK_COLL(comm, ompi_module, saved, module->prev_colls, reduce); + + if(prev_fns_dst) { + *prev_fns_dst = saved; + } +} + +/* */ +void mca_coll_xhc_module_install_fns(xhc_module_t *module, + ompi_communicator_t *comm, xhc_coll_fns_t fns) { + + (void) module; + + INSTALL_COLL(comm, fns, allreduce); + INSTALL_COLL(comm, fns, barrier); + INSTALL_COLL(comm, fns, bcast); + INSTALL_COLL(comm, fns, reduce); +} + +// ----------------------------- + +int mca_coll_xhc_module_prepare_hierarchy( + xhc_module_t *module, ompi_communicator_t *comm) { + + int ret; + + opal_list_t *level_defs; + int nlevel_defs; + + ret = xhc_component_parse_hierarchy(module->hierarchy_string, + &level_defs, &nlevel_defs); + if(ret != OMPI_SUCCESS) { + return ret; + } + + ret = xhc_module_create_hierarchy(module, comm, level_defs, + nlevel_defs, &module->hierarchy, &module->hierarchy_len); + if(ret != OMPI_SUCCESS) { + return ret; + } + + for(int i = 0; i < nlevel_defs; i++) + OPAL_LIST_DESTRUCT(&level_defs[i]); + free(level_defs); + + ret = xhc_module_sort_hierarchy(module, comm, + &module->hierarchy, &module->hierarchy_len); + if(ret != OMPI_SUCCESS) { + return ret; + } + + return OMPI_SUCCESS; +} + +static int xhc_module_create_hierarchy(xhc_module_t *module, + ompi_communicator_t *comm, opal_list_t *level_defs, int nlevel_defs, + xhc_loc_t **hierarchy_dst, int *hierarchy_len_dst) { + + xhc_peer_info_t *peer_info = module->peer_info; + + int comm_size = ompi_comm_size(comm); + int rank = ompi_comm_rank(comm); + + xhc_loc_t *hierarchy = NULL; + int nvirt_hiers = 0; + + int *rank_list; + + opal_hwloc_locality_t *loc_list; + ompi_datatype_t *hwloc_locality_type = NULL; + + int ret, return_code = OMPI_SUCCESS; + + hierarchy = malloc(nlevel_defs * sizeof(xhc_loc_t)); + rank_list = malloc(comm_size * sizeof(int)); + loc_list = malloc(comm_size * sizeof(opal_hwloc_locality_t)); + + if(!hierarchy || !rank_list || !loc_list) { + RETURN_WITH_ERROR(return_code, OMPI_ERR_OUT_OF_RESOURCE, end); + } + + switch(sizeof(opal_hwloc_locality_t)) { + case 1: hwloc_locality_type = MPI_UINT8_T; break; + case 2: hwloc_locality_type = MPI_UINT16_T; break; + case 4: hwloc_locality_type = MPI_UINT32_T; break; + case 8: hwloc_locality_type = MPI_UINT64_T; break; + } + assert(hwloc_locality_type); + + 
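/* Note: PEER_IS_LOCAL() is defined in coll_xhc.h and is not shown in this
+ * patch. Throughout the loop below it is assumed here to behave as a
+ * bitwise locality test, roughly:
+ *
+ *   (peer_info[r].locality & (loc)) == (loc)
+ *
+ * i.e. "does rank r share every bit of the given locality with me". */
+
+ 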
for(int h = 0; h < nlevel_defs; h++) { + opal_list_t *defs = &level_defs[h]; + + xhc_loc_def_t *my_def = NULL; + xhc_loc_t locality; + + xhc_loc_def_t *def_0 = (xhc_loc_def_t *) opal_list_get_first(defs); + + bool is_virtual = (opal_list_get_size(defs) > 1 || def_0->rank_list + || def_0->split > 1 || def_0->max_ranks > 0); + + if(is_virtual) { + if(nvirt_hiers == OMPI_XHC_LOC_EXT_BITS) { + opal_output_verbose(MCA_BASE_VERBOSE_COMPONENT, + ompi_coll_base_framework.framework_output, + "coll:xhc: Error: Too many virtual hierarchies"); + + RETURN_WITH_ERROR(return_code, OMPI_ERR_NOT_SUPPORTED, end); + } + + locality = 1 << (OMPI_XHC_LOC_EXT_START + nvirt_hiers); + nvirt_hiers++; + } else { + locality = def_0->named_loc; + } + + hierarchy[h] = locality; + def_0 = NULL; + + xhc_loc_def_t *def, *def_next; + + /* Handle rank lists; take note if I belong + * in one, and remove them from the mix */ + OPAL_LIST_FOREACH_SAFE(def, def_next, defs, xhc_loc_def_t) { + if(def->rank_list) { + if(!my_def) { + for(int rl = 0; rl < def->rank_list_len; rl++) { + if(rank >= def->rank_list[rl].start_rank + && rank <= def->rank_list[rl].end_rank) { + my_def = def; + break; + } + } + } + + opal_list_remove_item(defs, (opal_list_item_t *) def); + if(def != my_def) { + OBJ_RELEASE(def); + } + } + } + + bool dir_fwd = true; + + /* When multiple locality definitions are present, they are assigned + * to groups in a left-to-right fashion. At every turn, the first + * rank (determined by the minimum ID) that's still not part of + * a locality, as well as the other ranks that are local with it, + * claim/consume the next locality from the list. The direction + * serves to implement the repeat modifier: when it is encountered, + * the process switches to right-to-left, following the max + * ID. At the end, after the loop, the repeated locality will + * be the only one left and all remaining ranks will follow it. */ + while(opal_list_get_size(defs) > 1) { + def = (xhc_loc_def_t *) (dir_fwd ? opal_list_get_first(defs) + : opal_list_get_last(defs)); + + if(dir_fwd && def->repeat) { + dir_fwd = false; + continue; + } + + int ticket = (my_def == NULL ? rank : (dir_fwd ? comm_size : -1)); + int chosen; + + ret = comm->c_coll->coll_allreduce(&ticket, &chosen, 1, + MPI_INT, (dir_fwd ? MPI_MIN : MPI_MAX), comm, + comm->c_coll->coll_allreduce_module); + if(ret != OMPI_SUCCESS) { + RETURN_WITH_ERROR(return_code, ret, end); + } + + if(chosen >= 0 && chosen < comm_size + && PEER_IS_LOCAL(peer_info, chosen, def->named_loc)) { + + my_def = def; + } + + opal_list_remove_item(defs, (opal_list_item_t *) def); + if(def != my_def) { + OBJ_RELEASE(def); + } + } + + if(opal_list_get_size(defs) > 0 && !my_def) { + my_def = (xhc_loc_def_t *) opal_list_get_first(defs); + opal_list_remove_item(defs, (opal_list_item_t *) my_def); + } + + /* Share which named locality each rank follows; ranks that + * follow different localities shouldn't be grouped together */ + opal_hwloc_locality_t follow_loc = (my_def ? 
my_def->named_loc : 0); + ret = comm->c_coll->coll_allgather(&follow_loc, 1, + hwloc_locality_type, loc_list, 1, hwloc_locality_type, + comm, comm->c_coll->coll_allgather_module); + if(ret != OMPI_SUCCESS) { + RETURN_WITH_ERROR(return_code, ret, end); + } + + if(my_def == NULL) { + continue; + } + + int member_id; + int members = 0; + + // If working with rank list, set the ranks from the list as "local" + if(my_def->rank_list) { + for(int i = 0; i < my_def->rank_list_len; i++) { + for(int r = my_def->rank_list[i].start_rank; + r <= my_def->rank_list[i].end_rank && r < comm_size; r++) { + if(r == rank) { + member_id = members; + } + + peer_info[r].locality |= locality; + rank_list[members++] = r; + } + } + } else if(is_virtual) { + /* We might have a named locality instead of a rank list, but if + * we still needed to create a virtual one, we need to apply it */ + for(int r = 0; r < comm_size; r++) { + if(loc_list[r] != my_def->named_loc) { + continue; + } + + if(!PEER_IS_LOCAL(peer_info, r, my_def->named_loc)) { + continue; + } + + if(r == rank) { + member_id = members; + } + + peer_info[r].locality |= locality; + rank_list[members++] = r; + } + } + + /* If split or max ranks was specified, math partition the locality + * and remove the previously added locality mapping to some ranks */ + if(my_def->split > 1) { + int piece_size = members / my_def->split; + int leftover = members % my_def->split; + + for(int m = 0, next_border = 0; m < members; m++) { + if(m == next_border) { + next_border += piece_size + (leftover > 0 ? 1 : 0); + if(leftover > 0) { + leftover--; + } + + if(member_id >= m && member_id < next_border) { + m = next_border - 1; + continue; + } + } + + peer_info[rank_list[m]].locality &= ~locality; + } + } else if(my_def->max_ranks > 1) { + for(int m = 0; m < members; m++) { + if(m % my_def->max_ranks == 0) { + if(member_id >= m && member_id - m < my_def->max_ranks) { + m += my_def->max_ranks - 1; + continue; + } + } + + peer_info[rank_list[m]].locality &= ~locality; + } + } + + OBJ_RELEASE_IF_NOT_NULL(my_def); + } + + *hierarchy_dst = hierarchy; + *hierarchy_len_dst = nlevel_defs; + +end: + + free(rank_list); + + if(return_code != OMPI_SUCCESS) { + free(hierarchy); + } + + return return_code; +} + +static int xhc_module_sort_hierarchy(xhc_module_t *module, + ompi_communicator_t *comm, xhc_loc_t **hierarchy_dst, + int *hierarchy_len_dst) { + + xhc_peer_info_t *peer_info = module->peer_info; + int comm_size = ompi_comm_size(comm); + + xhc_loc_t *old_hier = *hierarchy_dst; + int hier_len = *hierarchy_len_dst; + + xhc_loc_t *new_hier = NULL; + bool *hier_done = NULL; + + int return_code = OMPI_SUCCESS; + + new_hier = malloc((hier_len + 1) * sizeof(xhc_loc_t)); + hier_done = calloc(hier_len, sizeof(bool)); + + if(new_hier == NULL || hier_done == NULL) { + RETURN_WITH_ERROR(return_code, OMPI_ERR_OUT_OF_RESOURCE, end); + } + + bool has_virtual = false; + for(int i = 0; i < hier_len; i++) { + if(old_hier[i] >= (1 << OMPI_XHC_LOC_EXT_START)) { + has_virtual = true; + break; + } + } + + /* If any virtual hierarchy is involved, attempting to sort it is likely + * asking for trouble. Skip the sorting, and only consider adding a top + * common locality. There is a chance it wasn't actually necessary, but + * it never hurts. 
*/ + + if(has_virtual) { + memcpy(new_hier, old_hier, hier_len * sizeof(xhc_loc_t)); + } else { + for(int new_idx = hier_len - 1; new_idx >= 0; new_idx--) { + int max_matches_count = -1; + int max_matches_hier_idx = -1; + + for(int i = 0; i < hier_len; i++) { + if(hier_done[i]) { + continue; + } + + int matches = 0; + + for(int r = 0; r < comm_size; r++) { + if(PEER_IS_LOCAL(peer_info, r, old_hier[i])) { + matches++; + } + } + + if(matches > max_matches_count) { + max_matches_count = matches; + max_matches_hier_idx = i; + } + } + + assert(max_matches_count != -1); + + new_hier[new_idx] = old_hier[max_matches_hier_idx]; + hier_done[max_matches_hier_idx] = true; + } + } + + xhc_loc_t common_locality = (xhc_loc_t) -1; + + for(int r = 0; r < comm_size; r++) { + ompi_proc_t *proc = ompi_comm_peer_lookup(comm, r); + common_locality &= proc->super.proc_flags; + } + + if(common_locality == 0) { + opal_output_verbose(MCA_BASE_VERBOSE_COMPONENT, + ompi_coll_base_framework.framework_output, + "coll:xhc: Error: There is no locality common " + "to all ranks in the communicator"); + + RETURN_WITH_ERROR(return_code, OMPI_ERR_NOT_SUPPORTED, end); + } + + if(hier_len == 0 || (common_locality & new_hier[hier_len - 1]) + != new_hier[hier_len - 1]) { + + new_hier[hier_len] = common_locality; + hier_len++; + } + + REALLOC(new_hier, hier_len, xhc_loc_t); + + free(old_hier); + + *hierarchy_dst = new_hier; + *hierarchy_len_dst = hier_len; + +end: + + free(hier_done); + + if(return_code != OMPI_SUCCESS) { + free(new_hier); + } + + return return_code; +} diff --git a/ompi/mca/coll/xhc/coll_xhc_reduce.c b/ompi/mca/coll/xhc/coll_xhc_reduce.c new file mode 100644 index 00000000000..5f28986fb66 --- /dev/null +++ b/ompi/mca/coll/xhc/coll_xhc_reduce.c @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2021-2023 Computer Architecture and VLSI Systems (CARV) + * Laboratory, ICS Forth. All rights reserved. + * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +#include "ompi_config.h" +#include "mpi.h" + +#include "ompi/constants.h" +#include "ompi/datatype/ompi_datatype.h" +#include "ompi/communicator/communicator.h" +#include "ompi/op/op.h" + +#include "opal/mca/rcache/base/base.h" +#include "opal/util/show_help.h" +#include "opal/util/minmax.h" + +#include "coll_xhc.h" + +int mca_coll_xhc_reduce(const void *sbuf, void *rbuf, + int count, ompi_datatype_t *datatype, ompi_op_t *op, int root, + ompi_communicator_t *ompi_comm, mca_coll_base_module_t *ompi_module) { + + xhc_module_t *module = (xhc_module_t *) ompi_module; + + // Currently, XHC's reduce only supports root = 0 + if(root == 0) { + return xhc_allreduce_internal(sbuf, rbuf, count, + datatype, op, ompi_comm, ompi_module, false); + } else { + xhc_coll_fns_t fallback = module->prev_colls; + + return fallback.coll_reduce(sbuf, rbuf, count, datatype, + op, root, ompi_comm, fallback.coll_reduce_module); + } +} diff --git a/ompi/mca/coll/xhc/help-coll-xhc.txt b/ompi/mca/coll/xhc/help-coll-xhc.txt new file mode 100644 index 00000000000..453a96df4fc --- /dev/null +++ b/ompi/mca/coll/xhc/help-coll-xhc.txt @@ -0,0 +1,24 @@ +# +# Copyright (c) 2021-2023 Computer Architecture and VLSI Systems (CARV) +# Laboratory, ICS Forth. All rights reserved. 
+# $COPYRIGHT$ +# +# Additional copyrights may follow +# +# $HEADER$ +# + +[bad-hierarchy-item] +WARNING (coll/xhc) +Unrecognized locality definition '%s' in hierarchy parameter string '%s' +The component won't load +# +[bad-chunk-size-item] +WARNING (coll/xhc) +Malformed item '%s' in chunk size parameter string '%s' +The component won't load +# +[xhc-init-failed] +WARNING (coll/xhc) +Component initialization failed with error code %d +Errno: %d (%s) diff --git a/ompi/mca/coll/xhc/resources/xhc-hierarchy.svg b/ompi/mca/coll/xhc/resources/xhc-hierarchy.svg new file mode 100755 index 00000000000..c8f6d8a2da3 --- /dev/null +++ b/ompi/mca/coll/xhc/resources/xhc-hierarchy.svg @@ -0,0 +1,1176 @@ [1,176 lines of SVG markup not reproduced here; the figure depicts the XHC hierarchy: a System level above Socket and NUMA levels, cores P0-P15, and per-NUMA leaders (NUMA 0/1/3 Leader).] diff --git a/ompi/mpi/c/comm_get_info.c b/ompi/mpi/c/comm_get_info.c index 138f1656dcf..28bb8e776d6 100644 --- a/ompi/mpi/c/comm_get_info.c +++ b/ompi/mpi/c/comm_get_info.c @@ -61,7 +61,7 @@ int MPI_Comm_get_info(MPI_Comm comm, MPI_Info *info_used) opal_info_t *opal_info_used = &(*info_used)->super; - opal_info_dup(comm->super.s_info, &opal_info_used); + opal_info_dup_public(comm->super.s_info, &opal_info_used); return MPI_SUCCESS; } diff --git a/ompi/mpi/c/file_get_info.c b/ompi/mpi/c/file_get_info.c index 523602feb16..429da4af303 100644 --- a/ompi/mpi/c/file_get_info.c +++ b/ompi/mpi/c/file_get_info.c @@ -86,7 +86,7 @@ int MPI_File_get_info(MPI_File fh, MPI_Info *info_used) } opal_info_t *opal_info_used = &(*info_used)->super; - opal_info_dup(fh->super.s_info, &opal_info_used); + opal_info_dup_public(fh->super.s_info, &opal_info_used); return OMPI_SUCCESS; } diff --git a/ompi/mpi/c/win_get_info.c b/ompi/mpi/c/win_get_info.c index 7b842391735..a982b5986e1 100644 --- a/ompi/mpi/c/win_get_info.c +++ b/ompi/mpi/c/win_get_info.c @@ -60,7 +60,7 @@ int MPI_Win_get_info(MPI_Win win, MPI_Info *info_used) } opal_info_t *opal_info_used = &(*info_used)->super; - ret = opal_info_dup(win->super.s_info, &opal_info_used); + ret = opal_info_dup_public(win->super.s_info, &opal_info_used); OMPI_ERRHANDLER_RETURN(ret, win, ret, FUNC_NAME); } diff --git a/ompi/request/request_dbg.h b/ompi/request/request_dbg.h index e6aa3757d06..5929374ade4 100644 --- a/ompi/request/request_dbg.h +++ b/ompi/request/request_dbg.h @@ -1,6 +1,7 @@ /* -*- Mode: C; c-basic-offset:4 ; -*- */ /* * Copyright (c) 2009 Sun Microsystems, Inc. All rights reserved. + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. 
* $COPYRIGHT$ * * Additional copyrights may follow @@ -17,7 +18,7 @@ */ /** - * Enum inidicating the type of the request + * Enum indicating the type of the request */ typedef enum { OMPI_REQUEST_PML, /**< MPI point-to-point request */ diff --git a/ompi/win/win.c b/ompi/win/win.c index 70e70c978e8..2f0974ac016 100644 --- a/ompi/win/win.c +++ b/ompi/win/win.c @@ -266,9 +266,6 @@ ompi_win_create(void *base, size_t size, return ret; } - /* MPI-4 §12.2.7 requires us to remove all unknown keys from the info object */ - opal_info_remove_unreferenced(win->super.s_info); - *newwin = win; return OMPI_SUCCESS; @@ -300,9 +297,6 @@ ompi_win_allocate(size_t size, int disp_unit, opal_info_t *info, return ret; } - /* MPI-4 §12.2.7 requires us to remove all unknown keys from the info object */ - opal_info_remove_unreferenced(win->super.s_info); - *((void**) baseptr) = base; *newwin = win; @@ -335,9 +329,6 @@ ompi_win_allocate_shared(size_t size, int disp_unit, opal_info_t *info, return ret; } - /* MPI-4 §12.2.7 requires us to remove all unknown keys from the info object */ - opal_info_remove_unreferenced(win->super.s_info); - *((void**) baseptr) = base; *newwin = win; @@ -368,9 +359,6 @@ ompi_win_create_dynamic(opal_info_t *info, ompi_communicator_t *comm, ompi_win_t return ret; } - /* MPI-4 §12.2.7 requires us to remove all unknown keys from the info object */ - opal_info_remove_unreferenced(win->super.s_info); - *newwin = win; return OMPI_SUCCESS; diff --git a/opal/mca/btl/smcuda/btl_smcuda_component.c b/opal/mca/btl/smcuda/btl_smcuda_component.c index 78e06751222..4249b955bb2 100644 --- a/opal/mca/btl/smcuda/btl_smcuda_component.c +++ b/opal/mca/btl/smcuda/btl_smcuda_component.c @@ -198,8 +198,7 @@ static int smcuda_register(void) /* Lower priority when CUDA support is not requested */ if (0 != strcmp(opal_accelerator_base_selected_component.base_version.mca_component_name, "null")) { - - mca_btl_smcuda.super.btl_exclusivity = MCA_BTL_EXCLUSIVITY_HIGH + 1; + mca_btl_smcuda.super.btl_exclusivity = MCA_BTL_EXCLUSIVITY_DEFAULT; } else { mca_btl_smcuda.super.btl_exclusivity = MCA_BTL_EXCLUSIVITY_LOW; } diff --git a/opal/util/info.c b/opal/util/info.c index 1c73466f689..12981e8297c 100644 --- a/opal/util/info.c +++ b/opal/util/info.c @@ -65,14 +65,18 @@ OBJ_CLASS_INSTANCE(opal_info_entry_t, opal_list_item_t, info_entry_constructor, info_entry_destructor); /* - * Duplicate an info + * Duplicate an info into newinfo. If public_info is true we only duplicate + * key-value pairs that are not internal and that had been referenced, + * either through opal_info_get or opal_info_set. 
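+ *
+ * As a hypothetical illustration (the variable names below are made up and
+ * not part of this change): a key set with opal_info_set() is referenced and
+ * therefore copied by both variants, while a key set with
+ * opal_info_set_internal() is skipped by the public duplicate:
+ *
+ *   opal_info_t *copy = OBJ_NEW(opal_info_t);
+ *   opal_info_set(src, "user_key", "42");
+ *   opal_info_set_internal(src, "internal_key", "1");
+ *   opal_info_dup_public(src, &copy);   -> copies only "user_key"
+ *   opal_info_dup(src, &copy);          -> would copy both keys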
*/ -int opal_info_dup(opal_info_t *info, opal_info_t **newinfo) +static int opal_info_dup_impl(opal_info_t *info, opal_info_t **newinfo, bool public_only) { opal_info_entry_t *iterator; OPAL_THREAD_LOCK(info->i_lock); OPAL_LIST_FOREACH (iterator, &info->super, opal_info_entry_t) { + /* skip keys that are internal if we didn't ask for them */ + if (public_only && (iterator->ie_internal || iterator->ie_referenced == 0)) continue; /* create a new info entry and retain the string objects */ opal_info_entry_t *newentry = OBJ_NEW(opal_info_entry_t); newentry->ie_key = iterator->ie_key; @@ -85,6 +89,16 @@ int opal_info_dup(opal_info_t *info, opal_info_t **newinfo) return OPAL_SUCCESS; } +int opal_info_dup_public(opal_info_t *info, opal_info_t **newinfo) +{ + return opal_info_dup_impl(info, newinfo, true); +} + +int opal_info_dup(opal_info_t *info, opal_info_t **newinfo) +{ + return opal_info_dup_impl(info, newinfo, false); +} + static void opal_info_get_nolock(opal_info_t *info, const char *key, opal_cstring_t **value, int *flag) { @@ -136,7 +150,7 @@ static int opal_info_set_cstring_nolock(opal_info_t *info, const char *key, opal return OPAL_SUCCESS; } -static int opal_info_set_nolock(opal_info_t *info, const char *key, const char *value) +static int opal_info_set_nolock(opal_info_t *info, const char *key, const char *value, bool internal) { opal_info_entry_t *old_info; @@ -147,6 +161,7 @@ static int opal_info_set_nolock(opal_info_t *info, const char *key, const char * */ size_t value_len = strlen(value); old_info->ie_referenced++; + old_info->ie_internal = internal; if (old_info->ie_value->length == value_len && 0 == strcmp(old_info->ie_value->string, value)) { return OPAL_SUCCESS; @@ -171,6 +186,7 @@ static int opal_info_set_nolock(opal_info_t *info, const char *key, const char * return OPAL_ERR_OUT_OF_RESOURCE; } new_info->ie_referenced++; + new_info->ie_internal = internal; opal_list_append(&(info->super), (opal_list_item_t *) new_info); } return OPAL_SUCCESS; @@ -184,7 +200,20 @@ int opal_info_set(opal_info_t *info, const char *key, const char *value) int ret; OPAL_THREAD_LOCK(info->i_lock); - ret = opal_info_set_nolock(info, key, value); + ret = opal_info_set_nolock(info, key, value, false); + OPAL_THREAD_UNLOCK(info->i_lock); + return ret; +} + +/* + * Set a value on the info + */ +int opal_info_set_internal(opal_info_t *info, const char *key, const char *value) +{ + int ret; + + OPAL_THREAD_LOCK(info->i_lock); + ret = opal_info_set_nolock(info, key, value, true); OPAL_THREAD_UNLOCK(info->i_lock); return ret; } @@ -372,6 +401,7 @@ static void info_entry_constructor(opal_info_entry_t *entry) entry->ie_key = NULL; entry->ie_value = NULL; entry->ie_referenced = 0; + entry->ie_internal = false; } static void info_entry_destructor(opal_info_entry_t *entry) @@ -410,52 +440,3 @@ static opal_info_entry_t *info_find_key(opal_info_t *info, const char *key) } return NULL; } - -/** - * Mark the entry \c key as referenced. - */ -int opal_info_mark_referenced(opal_info_t *info, const char *key) -{ - opal_info_entry_t *entry; - - OPAL_THREAD_LOCK(info->i_lock); - entry = info_find_key(info, key); - entry->ie_referenced++; - OPAL_THREAD_UNLOCK(info->i_lock); - - return OPAL_SUCCESS; -} - -/** - * Remove a reference from the entry \c key. 
- */ -int opal_info_unmark_referenced(opal_info_t *info, const char *key) -{ - opal_info_entry_t *entry; - - OPAL_THREAD_LOCK(info->i_lock); - entry = info_find_key(info, key); - entry->ie_referenced--; - OPAL_THREAD_UNLOCK(info->i_lock); - - return OPAL_SUCCESS; -} - -/** - * Remove any entries that are not marked as referenced - */ -int opal_info_remove_unreferenced(opal_info_t *info) -{ - opal_info_entry_t *iterator, *next; - /* iterate over all entries and remove the ones that are not referenced */ - OPAL_THREAD_LOCK(info->i_lock); - OPAL_LIST_FOREACH_SAFE (iterator, next, &info->super, opal_info_entry_t) { - if (!iterator->ie_referenced) { - opal_list_remove_item(&info->super, &iterator->super); - } - } - OPAL_THREAD_UNLOCK(info->i_lock); - - - return OPAL_SUCCESS; -} diff --git a/opal/util/info.h b/opal/util/info.h index 8faee5170f0..5539b09d584 100644 --- a/opal/util/info.h +++ b/opal/util/info.h @@ -67,6 +67,7 @@ struct opal_info_entry_t { opal_cstring_t *ie_key; /**< "key" part of the (key, value) pair */ uint32_t ie_referenced; /**< number of times this entry was internally referenced */ + bool ie_internal; /**< internal keys are not handed back to the user */ }; /** @@ -90,7 +91,23 @@ OPAL_DECLSPEC OBJ_CLASS_DECLARATION(opal_info_t); OPAL_DECLSPEC OBJ_CLASS_DECLARATION(opal_info_entry_t); /** - * opal_info_dup - Duplicate an 'MPI_Info' object + * opal_info_dup - Duplicate public keys of an 'MPI_Info' object + * + * @param info source info object (handle) + * @param newinfo pointer to the new info object (handle) + * + * @retval OPAL_SUCCESS upon success + * @retval OPAL_ERR_OUT_OF_RESOURCE if out of memory + * + * Not only will the (key, value) pairs be duplicated, the order + * of keys will be the same in 'newinfo' as it is in 'info'. When + * an info object is no longer being used, it should be freed with + * \c opal_info_free. + */ +int opal_info_dup_public(opal_info_t *info, opal_info_t **newinfo); + +/** + * opal_info_dup - Duplicate all entries of an 'MPI_Info' object * * @param info source info object (handle) * @param newinfo pointer to the new info object (handle) @@ -117,6 +134,18 @@ int opal_info_dup(opal_info_t *info, opal_info_t **newinfo); */ OPAL_DECLSPEC int opal_info_set(opal_info_t *info, const char *key, const char *value); +/** + * Set a new key,value pair on info and mark it as internal. + * + * @param info pointer to opal_info_t object + * @param key pointer to the new key object + * @param value pointer to the new value object + * + * @retval OPAL_SUCCESS upon success + * @retval OPAL_ERR_OUT_OF_RESOURCE if out of memory + */ +OPAL_DECLSPEC int opal_info_set_internal(opal_info_t *info, const char *key, const char *value); + /** * Set a new key,value pair on info. * @@ -287,43 +316,6 @@ static inline int opal_info_get_nkeys(opal_info_t *info, int *nkeys) return OPAL_SUCCESS; } - -/** - * Mark the entry \c key as referenced. - * - * This function is useful for lazily initialized components - * that do not read the key immediately but want to make sure - * the key is kept by the object owning the info key. - * - * @param info Pointer to opal_info_t object. - * @param key The key which to mark as referenced. - * - * @retval OPAL_SUCCESS - */ -int opal_info_mark_referenced(opal_info_t *info, const char *key); - -/** - * Remove a reference from the entry \c key. - * - * This function should be used by components reading the key - * without wanting to retain it in the object owning the info. - * - * @param info Pointer to opal_info_t object. 
- * @param key The key which to unmark as referenced. - * - * @retval OPAL_SUCCESS - */ -int opal_info_unmark_referenced(opal_info_t *info, const char *key); - -/** - * Remove any entries that are not marked as referenced - * - * @param info Pointer to opal_info_t object. - * - * @retval OPAL_SUCCESS - */ -int opal_info_remove_unreferenced(opal_info_t *info); - END_C_DECLS #endif /* OPAL_INFO_H */ diff --git a/opal/util/info_subscriber.c b/opal/util/info_subscriber.c index 68dc7ef0871..3382612ac17 100644 --- a/opal/util/info_subscriber.c +++ b/opal/util/info_subscriber.c @@ -269,9 +269,10 @@ int opal_infosubscribe_change_info(opal_infosubscriber_t *object, opal_info_t *n updated_value = opal_infosubscribe_inform_subscribers(object, iterator->ie_key->string, iterator->ie_value->string, &found_callback); - if (NULL != updated_value - && 0 != strncmp(updated_value, value_str->string, value_str->length)) { - err = opal_info_set(object->s_info, iterator->ie_key->string, updated_value); + if (NULL != updated_value) { + err = opal_info_set(object->s_info, key_str->string, updated_value); + } else { + err = opal_info_set_internal(object->s_info, key_str->string, value_str->string); } OBJ_RELEASE(value_str); OBJ_RELEASE(key_str);