diff --git a/ompi/mca/coll/han/coll_han.h b/ompi/mca/coll/han/coll_han.h index 9ec12cc8e6c..08abfe8970a 100644 --- a/ompi/mca/coll/han/coll_han.h +++ b/ompi/mca/coll/han/coll_han.h @@ -102,6 +102,7 @@ struct mca_coll_han_allreduce_args_s { int seg_count; int root_up_rank; int root_low_rank; + int root_reduce_low_rank; int num_segments; int cur_seg; int w_rank; diff --git a/ompi/mca/coll/han/coll_han_allreduce.c b/ompi/mca/coll/han/coll_han_allreduce.c index 039913d7fdb..0971a39f1b7 100644 --- a/ompi/mca/coll/han/coll_han_allreduce.c +++ b/ompi/mca/coll/han/coll_han_allreduce.c @@ -6,6 +6,8 @@ * * Copyright (c) 2020 Cisco Systems, Inc. All rights reserved. * Copyright (c) 2022 IBM Corporation. All rights reserved + * Copyright (c) 2023 Computer Architecture and VLSI Systems (CARV) + * Laboratory, ICS Forth. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -22,6 +24,7 @@ #include "coll_han.h" #include "ompi/mca/coll/base/coll_base_functions.h" +#include "ompi/mca/coll/base/coll_base_util.h" #include "ompi/mca/coll/base/coll_tags.h" #include "ompi/mca/pml/pml.h" #include "coll_han_trigger.h" @@ -43,6 +46,7 @@ mca_coll_han_set_allreduce_args(mca_coll_han_allreduce_args_t * args, struct ompi_op_t *op, int root_up_rank, int root_low_rank, + int root_reduce_low_rank, struct ompi_communicator_t *up_comm, struct ompi_communicator_t *low_comm, int num_segments, @@ -59,6 +63,7 @@ mca_coll_han_set_allreduce_args(mca_coll_han_allreduce_args_t * args, args->op = op; args->root_up_rank = root_up_rank; args->root_low_rank = root_low_rank; + args->root_reduce_low_rank = root_reduce_low_rank; args->up_comm = up_comm; args->low_comm = low_comm; args->num_segments = num_segments; @@ -139,6 +144,17 @@ mca_coll_han_allreduce_intra(const void *sbuf, int low_rank = ompi_comm_rank(low_comm); int root_up_rank = 0; int root_low_rank = 0; + int root_reduce_low_rank = 0; + + mca_coll_base_avail_coll_t *low_1st_module = (mca_coll_base_avail_coll_t *) + opal_list_get_last(low_comm->c_coll->module_list); + + // Invoke XHC's "special" Reduce + if(0 == strcmp(low_1st_module->ac_component_name, "xhc") + && low_comm->c_coll->coll_reduce_module == low_1st_module->ac_module) { + root_reduce_low_rank = -1; + } + /* Create t0 task for the first segment */ mca_coll_task_t *t0 = OBJ_NEW(mca_coll_task_t); /* Setup up t0 task arguments */ @@ -146,8 +162,8 @@ mca_coll_han_allreduce_intra(const void *sbuf, completed[0] = 0; mca_coll_han_allreduce_args_t *t = malloc(sizeof(mca_coll_han_allreduce_args_t)); mca_coll_han_set_allreduce_args(t, t0, (char *) sbuf, (char *) rbuf, seg_count, dtype, op, - root_up_rank, root_low_rank, up_comm, low_comm, num_segments, 0, - w_rank, count - (num_segments - 1) * seg_count, + root_up_rank, root_low_rank, root_reduce_low_rank, up_comm, + low_comm, num_segments, 0, w_rank, count - (num_segments - 1) * seg_count, low_rank != root_low_rank, NULL, completed); /* Init t0 task */ init_task(t0, mca_coll_han_allreduce_t0_task, (void *) (t)); @@ -215,18 +231,18 @@ int mca_coll_han_allreduce_t0_task(void *task_args) if (MPI_IN_PLACE == t->sbuf) { if (!t->noop) { t->low_comm->c_coll->coll_reduce(MPI_IN_PLACE, (char *) t->rbuf, t->seg_count, t->dtype, - t->op, t->root_low_rank, t->low_comm, + t->op, t->root_reduce_low_rank, t->low_comm, t->low_comm->c_coll->coll_reduce_module); } else { t->low_comm->c_coll->coll_reduce((char *) t->rbuf, NULL, t->seg_count, t->dtype, - t->op, t->root_low_rank, t->low_comm, + t->op, t->root_reduce_low_rank, t->low_comm, t->low_comm->c_coll->coll_reduce_module); } } else { t->low_comm->c_coll->coll_reduce((char *) t->sbuf, (char *) t->rbuf, t->seg_count, t->dtype, - t->op, t->root_low_rank, t->low_comm, + t->op, t->root_reduce_low_rank, t->low_comm, t->low_comm->c_coll->coll_reduce_module); } return OMPI_SUCCESS; @@ -267,21 +283,20 @@ int mca_coll_han_allreduce_t1_task(void *task_args) if (!t->noop) { t->low_comm->c_coll->coll_reduce(MPI_IN_PLACE, (char *) t->rbuf + extent * t->seg_count, tmp_count, - t->dtype, t->op, t->root_low_rank, t->low_comm, + t->dtype, t->op, t->root_reduce_low_rank, t->low_comm, t->low_comm->c_coll->coll_reduce_module); } else { t->low_comm->c_coll->coll_reduce((char *) t->rbuf + extent * t->seg_count, NULL, tmp_count, - t->dtype, t->op, t->root_low_rank, t->low_comm, + t->dtype, t->op, t->root_reduce_low_rank, t->low_comm, t->low_comm->c_coll->coll_reduce_module); - } } else { t->low_comm->c_coll->coll_reduce((char *) t->sbuf + extent * t->seg_count, (char *) t->rbuf + extent * t->seg_count, tmp_count, - t->dtype, t->op, t->root_low_rank, t->low_comm, + t->dtype, t->op, t->root_reduce_low_rank, t->low_comm, t->low_comm->c_coll->coll_reduce_module); - } + } } if (!t->noop) { ompi_request_wait(&ireduce_req, MPI_STATUS_IGNORE); @@ -337,25 +352,25 @@ int mca_coll_han_allreduce_t2_task(void *task_args) tmp_count = t->last_seg_count; } - if (t->sbuf == MPI_IN_PLACE) { - if (!t->noop) { - t->low_comm->c_coll->coll_reduce(MPI_IN_PLACE, + if (t->sbuf == MPI_IN_PLACE) { + if (!t->noop) { + t->low_comm->c_coll->coll_reduce(MPI_IN_PLACE, + (char *) t->rbuf + 2 * extent * t->seg_count, tmp_count, + t->dtype, t->op, t->root_reduce_low_rank, t->low_comm, + t->low_comm->c_coll->coll_reduce_module); + } else { + t->low_comm->c_coll->coll_reduce((char *) t->rbuf + 2 * extent * t->seg_count, + NULL, tmp_count, + t->dtype, t->op, t->root_reduce_low_rank, t->low_comm, + t->low_comm->c_coll->coll_reduce_module); + + } + } else { + t->low_comm->c_coll->coll_reduce((char *) t->sbuf + 2 * extent * t->seg_count, (char *) t->rbuf + 2 * extent * t->seg_count, tmp_count, - t->dtype, t->op, t->root_low_rank, t->low_comm, - t->low_comm->c_coll->coll_reduce_module); - } else { - t->low_comm->c_coll->coll_reduce((char *) t->rbuf + 2 * extent * t->seg_count, - NULL, tmp_count, - t->dtype, t->op, t->root_low_rank, t->low_comm, + t->dtype, t->op, t->root_reduce_low_rank, t->low_comm, t->low_comm->c_coll->coll_reduce_module); - - } - } else { - t->low_comm->c_coll->coll_reduce((char *) t->sbuf + 2 * extent * t->seg_count, - (char *) t->rbuf + 2 * extent * t->seg_count, tmp_count, - t->dtype, t->op, t->root_low_rank, t->low_comm, - t->low_comm->c_coll->coll_reduce_module); - } + } } if (!t->noop && req_count > 0) { ompi_request_wait_all(req_count, reqs, MPI_STATUSES_IGNORE); @@ -421,18 +436,18 @@ int mca_coll_han_allreduce_t3_task(void *task_args) if (!t->noop) { t->low_comm->c_coll->coll_reduce(MPI_IN_PLACE, (char *) t->rbuf + 3 * extent * t->seg_count, tmp_count, - t->dtype, t->op, t->root_low_rank, t->low_comm, + t->dtype, t->op, t->root_reduce_low_rank, t->low_comm, t->low_comm->c_coll->coll_reduce_module); - } else { + } else { t->low_comm->c_coll->coll_reduce((char *) t->rbuf + 3 * extent * t->seg_count, NULL, tmp_count, - t->dtype, t->op, t->root_low_rank, t->low_comm, + t->dtype, t->op, t->root_reduce_low_rank, t->low_comm, t->low_comm->c_coll->coll_reduce_module); } } else { t->low_comm->c_coll->coll_reduce((char *) t->sbuf + 3 * extent * t->seg_count, (char *) t->rbuf + 3 * extent * t->seg_count, tmp_count, - t->dtype, t->op, t->root_low_rank, t->low_comm, + t->dtype, t->op, t->root_reduce_low_rank, t->low_comm, t->low_comm->c_coll->coll_reduce_module); } } @@ -473,6 +488,7 @@ mca_coll_han_allreduce_intra_simple(const void *sbuf, ompi_communicator_t *low_comm; ompi_communicator_t *up_comm; int root_low_rank = 0; + int root_reduce_low_rank = 0; int low_rank; int ret; mca_coll_han_module_t *han_module = (mca_coll_han_module_t *)module; @@ -504,22 +520,31 @@ mca_coll_han_allreduce_intra_simple(const void *sbuf, up_comm = han_module->sub_comm[INTER_NODE]; low_rank = ompi_comm_rank(low_comm); + mca_coll_base_avail_coll_t *low_1st_module = (mca_coll_base_avail_coll_t *) + opal_list_get_last(low_comm->c_coll->module_list); + + // Invoke XHC's "special" Reduce + if(0 == strcmp(low_1st_module->ac_component_name, "xhc") + && low_comm->c_coll->coll_reduce_module == low_1st_module->ac_module) { + root_reduce_low_rank = -1; + } + /* Low_comm reduce */ if (MPI_IN_PLACE == sbuf) { if (low_rank == root_low_rank) { ret = low_comm->c_coll->coll_reduce(MPI_IN_PLACE, (char *)rbuf, - count, dtype, op, root_low_rank, + count, dtype, op, root_reduce_low_rank, low_comm, low_comm->c_coll->coll_reduce_module); } else { ret = low_comm->c_coll->coll_reduce((char *)rbuf, NULL, - count, dtype, op, root_low_rank, + count, dtype, op, root_reduce_low_rank, low_comm, low_comm->c_coll->coll_reduce_module); } } else { ret = low_comm->c_coll->coll_reduce((char *)sbuf, (char *)rbuf, - count, dtype, op, root_low_rank, + count, dtype, op, root_reduce_low_rank, low_comm, low_comm->c_coll->coll_reduce_module); } if (OPAL_UNLIKELY(OMPI_SUCCESS != ret)) {